Files
timmy-config/tests/test_scene_descriptions.py

168 lines
5.2 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Tests for generate_scene_descriptions.py
Tests the scene description generation pipeline including:
- Media file scanning
- Model detection
- JSON parsing from vision responses
- Output format validation
Ref: timmy-config#689
"""
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
# Add scripts to path for import
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "scripts"))
from generate_scene_descriptions import (
IMAGE_EXTS,
VIDEO_EXTS,
ALL_EXTS,
VISION_MODELS,
auto_detect_model,
check_model_available,
scan_media,
extract_video_frame,
)
class TestMediaScanning(unittest.TestCase):
    """Test media file scanning.

    scan_media() returns Path-like objects (the sorted-output test reads
    ``.name`` from each entry), so these tests compare the returned base
    names instead of only counting results: a count-only check would still
    pass if the scanner returned the wrong files or duplicates.
    """

    @staticmethod
    def _names(paths):
        """Return the sorted base names of ``paths``."""
        return sorted(p.name for p in paths)

    def test_scan_empty_directory(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            result = scan_media(tmpdir)
            self.assertEqual(result, [])

    def test_scan_nonexistent_directory(self):
        # Build the missing path under a fresh temp dir so it is guaranteed
        # not to exist, whatever machine runs the suite.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing = Path(tmpdir) / "does" / "not" / "exist"
            result = scan_media(str(missing))
            self.assertEqual(result, [])

    def test_scan_with_images(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create one file per supported image extension.
            for ext in [".jpg", ".png", ".webp"]:
                (Path(tmpdir) / f"test{ext}").touch()
            result = scan_media(tmpdir)
            self.assertEqual(
                self._names(result), ["test.jpg", "test.png", "test.webp"]
            )

    def test_scan_recursive(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            subdir = Path(tmpdir) / "sub" / "dir"
            subdir.mkdir(parents=True)
            (subdir / "deep.jpg").touch()
            (Path(tmpdir) / "top.png").touch()
            result = scan_media(tmpdir)
            # Both the nested and the top-level file must be found.
            self.assertEqual(self._names(result), ["deep.jpg", "top.png"])

    def test_scan_ignores_unsupported(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            (Path(tmpdir) / "image.jpg").touch()
            (Path(tmpdir) / "document.pdf").touch()
            (Path(tmpdir) / "script.py").touch()
            result = scan_media(tmpdir)
            self.assertEqual(self._names(result), ["image.jpg"])

    def test_scan_sorted_output(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create files out of alphabetical order on purpose.
            for name in ["z.jpg", "a.png", "m.webp"]:
                (Path(tmpdir) / name).touch()
            result = scan_media(tmpdir)
            names = [p.name for p in result]
            self.assertEqual(names, sorted(names))
class TestModelDetection(unittest.TestCase):
    """Test model availability detection."""

    @staticmethod
    def _fake_tags_response(*model_names):
        """Build a fake ``urlopen()`` response whose body lists ``model_names``.

        The mock returns itself from ``__enter__``. The same fake therefore
        works whether check_model_available() uses the result as a context
        manager (``with urlopen(...) as resp``) or calls ``.read()`` on it
        directly. A bare MagicMock would hand back an unconfigured child
        from ``__enter__``, whose ``.read()`` is not JSON bytes.
        """
        resp = MagicMock()
        resp.read.return_value = json.dumps(
            {"models": [{"name": name} for name in model_names]}
        ).encode()
        resp.__enter__.return_value = resp
        resp.__exit__.return_value = False
        return resp

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_check_model_available(self, mock_urlopen):
        mock_urlopen.return_value = self._fake_tags_response("gemma4:latest")
        result = check_model_available("gemma4:latest")
        self.assertTrue(result)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_check_model_not_available(self, mock_urlopen):
        mock_urlopen.return_value = self._fake_tags_response("llama2:7b")
        result = check_model_available("gemma4:latest")
        self.assertFalse(result)
        # Guard against a vacuous pass: False must come from inspecting the
        # listed models, not from never querying the server.
        mock_urlopen.assert_called()

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_prefers_gemma4(self, mock_check):
        # Accept any extra positional/keyword args (e.g. a server URL) so the
        # stub does not depend on the exact call signature auto_detect_model
        # uses.
        def side_effect(model, *_args, **_kwargs):
            return model == "gemma4:latest"
        mock_check.side_effect = side_effect
        result = auto_detect_model()
        self.assertEqual(result, "gemma4:latest")

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_falls_back(self, mock_check):
        def side_effect(model, *_args, **_kwargs):
            return model == "llava:latest"
        mock_check.side_effect = side_effect
        result = auto_detect_model()
        self.assertEqual(result, "llava:latest")

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_returns_none_when_no_models(self, mock_check):
        mock_check.return_value = False
        result = auto_detect_model()
        self.assertIsNone(result)
class TestConstants(unittest.TestCase):
    """Sanity checks for the module-level extension sets and model list."""

    def test_image_extensions(self):
        for suffix in (".jpg", ".png", ".webp"):
            with self.subTest(suffix=suffix):
                self.assertIn(suffix, IMAGE_EXTS)

    def test_video_extensions(self):
        for suffix in (".mp4", ".webm"):
            with self.subTest(suffix=suffix):
                self.assertIn(suffix, VIDEO_EXTS)

    def test_all_extensions_union(self):
        # ALL_EXTS must be exactly the combination of image and video types.
        combined = IMAGE_EXTS | VIDEO_EXTS
        self.assertEqual(ALL_EXTS, combined)

    def test_vision_models_ordered(self):
        # The first entry is the preferred model; llava must be a fallback.
        first_choice = VISION_MODELS[0]
        self.assertEqual(first_choice, "gemma4:latest")
        self.assertIn("llava:latest", VISION_MODELS)
class TestVideoFrameExtraction(unittest.TestCase):
    """Test video frame extraction."""

    def test_extract_nonexistent_video(self):
        # Use a private temp dir instead of a hard-coded /tmp path. /tmp is
        # not portable (e.g. Windows) and is shared with other processes.
        # Placing the input path there also guarantees it does not exist.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing_video = Path(tmpdir) / "nonexistent.mp4"
            frame_path = Path(tmpdir) / "frame.jpg"
            result = extract_video_frame(missing_video, frame_path)
        self.assertFalse(result)
if __name__ == "__main__":
    # Allow running this file directly as a script; unittest.main() collects
    # and runs the TestCase classes defined in this module.
    unittest.main()