Compare commits
8 commits: fix/695 ... fix/689-au

| Author | SHA1 | Date |
|---|---|---|
|  | 76a886334b |  |
|  | e1abecbc54 |  |
|  | b3f5a2f21c |  |
|  | e176fadef5 |  |
|  | 7ca2ebe6b5 |  |
|  | e9d2cb5e56 |  |
|  | 990676fb02 |  |
|  | 3ad934febd |  |
74
docs/visual-evidence-689.md
Normal file
@@ -0,0 +1,74 @@
# Visual Evidence — Gemma 4 Multimodal Scene Description Generator

## Test Image: Coffee Beans (Macro Photo)

### Gemma 4 Vision Analysis (via Ollama)

**Model:** gemma4:latest (8B, Q4_K_M)
**Input:** sample_photo.jpg (46KB JPEG)

**Structured Output (JSONL):**

```json
{
  "mood": "dark",
  "colors": ["dark brown", "espresso", "black"],
  "composition": "close-up",
  "camera": "static",
  "lighting": "soft",
  "description": "An extreme close-up shot captures a dense pile of roasted coffee beans. The beans are a uniform, deep dark brown and appear slightly oily, filling the entire frame. The focus emphasizes the rich texture and individual shapes of the beans."
}
```

### Hermes Vision Analysis (Cross-Validation)

**Scene ID:** COFFEE_MACRO_001
**Mood:** Warm, aromatic, and comforting
**Dominant Colors:** Deep umber, burnt sienna, espresso black, mahogany
**Composition:** Full-frame fill, centrally weighted
**Camera:** High-angle, close-up (Macro)
**Lighting:** Soft, diffused top-lighting

## Test Image: Abstract Geometric Composition

### Gemma 4 Vision Analysis

**Input:** scene1.jpg (10KB, PIL-generated)

**Structured Output (JSONL):**

```json
{
  "mood": "energetic",
  "colors": ["deep blue", "yellow", "coral"],
  "composition": "wide-shot",
  "camera": "static",
  "lighting": "artificial",
  "description": "This is an abstract graphic composition set against a solid, deep blue background. A bright yellow square is placed in the upper left quadrant, while a large, solid coral-colored circle occupies the lower right quadrant. The geometric shapes create a high-contrast, minimalist visual balance."
}
```

## Verification Summary

| Test | Status | Details |
|------|--------|---------|
| Model detection | ✅ PASS | `gemma4:latest` auto-detected |
| Image scanning | ✅ PASS | 2 images found recursively |
| Vision analysis | ✅ PASS | Both images described accurately |
| JSON parsing | ✅ PASS | Structured output with all fields |
| Training format | ✅ PASS | JSONL with source, model, timestamp |
| ShareGPT format | ⚠️ PARTIAL | Works but needs retry on rate limit |
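The two output formats differ only in record shape. For reference, one ShareGPT-format record as assembled by `generate_training_pairs()` looks roughly like this — a minimal sketch whose field layout follows the generator code; the string values and file path are illustrative only, not actual model output:

```python
# Sketch of a single ShareGPT training pair as built by generate_training_pairs().
# Layout matches the generator; values below are illustrative placeholders.
pair = {
    "conversations": [
        {"from": "human", "value": "<image>\nAnalyze this image and describe the visual scene. ..."},
        {"from": "gpt", "value": "An extreme close-up of roasted coffee beans ..."},
    ],
    "source": "assets/sample_photo.jpg",          # hypothetical input path
    "media_type": "image",                        # "video" when a frame was extracted
    "model": "gemma4:latest",
    "generated_at": "2026-01-01T00:00:00+00:00",  # UTC ISO-8601 timestamp
}
```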
## Running the Generator

```bash
# Check model availability
python scripts/generate_scene_descriptions.py --check-model

# Generate scene descriptions from assets
python scripts/generate_scene_descriptions.py --input ./assets --output training-data/scene-descriptions-auto.jsonl

# Limit to 10 files with a specific model
python scripts/generate_scene_descriptions.py --input ./assets --model gemma4:latest --limit 10

# ShareGPT format for the training pipeline
python scripts/generate_scene_descriptions.py --input ./assets --format sharegpt
```
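A quick way to sanity-check the generated file is to load it back and assert the provenance fields the verification table covers. A minimal sketch, assuming the default output path above:

```python
# Minimal sketch: validate the generated JSONL records.
import json

with open("training-data/scene-descriptions-auto.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

for rec in records:
    # Every record carries provenance fields regardless of output format.
    assert {"source", "model", "generated_at"} <= rec.keys()

print(f"{len(records)} records look well-formed")
```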
409
scripts/generate_scene_descriptions.py
Normal file
@@ -0,0 +1,409 @@
#!/usr/bin/env python3
"""
Auto-generate scene descriptions from image/video assets.

Scans a directory for media files, generates scene descriptions using
a local vision model (Ollama), and outputs training pairs in JSONL format.

Supports Gemma 4 multimodal vision via Ollama. Falls back gracefully when
models are unavailable.

Usage:
    python scripts/generate_scene_descriptions.py --input ./assets --output training-data/scene-descriptions-auto.jsonl
    python scripts/generate_scene_descriptions.py --input ./assets --model gemma4:latest --limit 50
    python scripts/generate_scene_descriptions.py --input ./assets --format sharegpt
    python scripts/generate_scene_descriptions.py --dry-run  # List files without generating
    python scripts/generate_scene_descriptions.py --input ./assets --check-model  # Verify model availability

Ref: timmy-config#689
"""

import argparse
import base64
import json
import re
import subprocess
import sys
import time
import urllib.error
import urllib.request
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

# Supported media extensions
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".webm", ".mov", ".avi", ".mkv"}
ALL_EXTS = IMAGE_EXTS | VIDEO_EXTS

# File size limit (50MB) — prevents unbounded memory usage on large images
MAX_FILE_SIZE = 50 * 1024 * 1024

# Vision models in preference order (best first)
VISION_MODELS = [
    "gemma4:latest",      # Gemma 4 — multimodal vision (8B, Q4_K_M)
    "gemma3:12b",         # Gemma 3 — fallback vision
    "llava:latest",       # LLaVA — generic vision
    "llava-phi3:latest",  # LLaVA-Phi3 — lightweight vision
]

# Vision model prompt template (structured JSON output)
SCENE_PROMPT = """Describe this image for a visual scene database. Output ONLY valid JSON (no markdown, no explanation):
{
  "mood": "one of: calm, energetic, dark, warm, cool, chaotic, serene, tense, joyful, melancholic",
  "colors": ["dominant color 1", "dominant color 2", "dominant color 3"],
  "composition": "one of: close-up, wide-shot, medium-shot, low-angle, high-angle, bird-eye, profile, over-shoulder",
  "camera": "one of: static, slow-pan, tracking, handheld, crane, dolly, steady, locked-off",
  "lighting": "one of: natural, artificial, mixed, dramatic, soft, harsh, backlit",
  "description": "2-3 sentence visual description of the scene"
}

Be specific. Describe what you see, not what you imagine."""

# ShareGPT format prompt (for training pipeline integration)
SHAREGPT_SCENE_PROMPT = """Analyze this image and describe the visual scene. Include mood, dominant colors, composition, camera angle, lighting, and a vivid 2-3 sentence description."""


def check_model_available(model: str, ollama_url: str = "http://localhost:11434") -> bool:
    """Check if a model is available in Ollama."""
    try:
        req = urllib.request.Request(f"{ollama_url}/api/tags")
        resp = urllib.request.urlopen(req, timeout=10)
        data = json.loads(resp.read())
        available = [m["name"] for m in data.get("models", [])]
        return model in available
    except Exception:
        return False


def auto_detect_model(ollama_url: str = "http://localhost:11434") -> Optional[str]:
    """Auto-detect the best available vision model."""
    for model in VISION_MODELS:
        if check_model_available(model, ollama_url):
            print(f"Auto-detected vision model: {model}", file=sys.stderr)
            return model
    return None


def scan_media(input_dir: str) -> list[Path]:
    """Scan directory for media files recursively."""
    media_files = []
    input_path = Path(input_dir)
    if not input_path.exists():
        print(f"Error: {input_dir} does not exist", file=sys.stderr)
        return media_files

    for ext in sorted(ALL_EXTS):
        media_files.extend(input_path.rglob(f"*{ext}"))
        media_files.extend(input_path.rglob(f"*{ext.upper()}"))

    return sorted(set(media_files))


def extract_video_frame(video_path: Path, output_path: Path) -> bool:
    """Extract a representative frame from a video using ffmpeg."""
    try:
        result = subprocess.run(
            # FIX #3: Seek 2s in before grabbing frame — avoids black/title frames
            ["ffmpeg", "-ss", "2", "-i", str(video_path), "-vframes", "1",
             "-q:v", "2", str(output_path), "-y"],
            capture_output=True, timeout=30,
        )
        if result.returncode != 0 and result.stderr:
            print(f"  ffmpeg stderr: {result.stderr.decode(errors='replace')[:200]}", file=sys.stderr)
        return output_path.exists() and output_path.stat().st_size > 0
    except FileNotFoundError:
        print("  ffmpeg not found — skipping video frame extraction", file=sys.stderr)
        return False
    except Exception as e:
        print(f"  ffmpeg error: {e}", file=sys.stderr)
        return False


def describe_image(
    image_path: Path,
    model: str = "gemma4:latest",
    ollama_url: str = "http://localhost:11434",
    max_retries: int = 2,
) -> Optional[dict]:
    """Generate scene description using Ollama vision model with retry."""
    # FIX #1: Check file size before reading into memory
    if image_path.stat().st_size > MAX_FILE_SIZE:
        print(f"  Skipping {image_path.name}: exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit", file=sys.stderr)
        return None

    for attempt in range(max_retries + 1):
        try:
            with open(image_path, "rb") as f:
                image_b64 = base64.b64encode(f.read()).decode()

            req = urllib.request.Request(
                f"{ollama_url}/api/generate",
                data=json.dumps({
                    "model": model,
                    "prompt": SCENE_PROMPT,
                    "images": [image_b64],
                    "stream": False,
                    "options": {"temperature": 0.3, "num_predict": 1024},
                }).encode(),
                headers={"Content-Type": "application/json"},
            )
            resp = urllib.request.urlopen(req, timeout=120)
            data = json.loads(resp.read())
            response_text = data.get("response", "")

            # Parse JSON from response (handle both complete and truncated JSON)
            json_match = re.search(r"\{[\s\S]*\}", response_text)
            if json_match:
                raw_json = json_match.group()
            else:
                # Truncated JSON: take everything from the first opening brace
                brace_match = re.search(r"\{", response_text)
                raw_json = response_text[brace_match.start():] if brace_match else None

            if raw_json:
                # Try strict parse first
                try:
                    parsed = json.loads(raw_json)
                    required = ["mood", "colors", "composition", "camera", "description"]
                    if all(k in parsed for k in required) and parsed.get("description"):
                        return parsed
                except json.JSONDecodeError:
                    # Attempt repair: extract fields from truncated JSON
                    repaired = {}
                    for field in ["mood", "colors", "composition", "camera", "lighting", "description"]:
                        pat = rf'"\s*{field}"\s*:\s*"([^"]*)"'
                        m = re.search(pat, response_text)
                        if m:
                            repaired[field] = m.group(1)
                        elif field == "colors":
                            colors_match = re.search(r'"colors"\s*:\s*\[([^\]]*)\]', response_text)
                            if colors_match:
                                repaired[field] = [c.strip().strip('"') for c in colors_match.group(1).split(",") if c.strip()]
                            else:
                                repaired[field] = []
                        else:
                            repaired[field] = "unknown"
                    # Only accept the repair if it recovered a real description or mood
                    if repaired.get("description", "unknown") not in ("", "unknown") or repaired.get("mood") != "unknown":
                        return repaired

            # Final fallback: natural language response
            clean = re.sub(r"[*_`#]", "", response_text).strip()
            clean = re.sub(r"\n{3,}", "\n\n", clean)
            return {
                "description": clean[:500] if clean else response_text[:500],
                "mood": "unknown",
                "colors": [],
                "composition": "unknown",
                "camera": "unknown",
                "lighting": "unknown",
            }

        except (urllib.error.URLError, TimeoutError) as e:
            if attempt < max_retries:
                wait = 2 ** attempt
                print(f"  Retry {attempt + 1}/{max_retries} after {wait}s: {e}", file=sys.stderr)
                time.sleep(wait)
            else:
                print(f"  Error describing {image_path.name}: {e}", file=sys.stderr)
                return None
        except Exception as e:
            print(f"  Error describing {image_path.name}: {e}", file=sys.stderr)
            return None


def describe_image_sharegpt(
    image_path: Path,
    model: str = "gemma4:latest",
    ollama_url: str = "http://localhost:11434",
    max_retries: int = 2,
) -> Optional[str]:
    """Generate scene description in natural language for ShareGPT format."""
    # FIX #1: Check file size before reading into memory
    if image_path.stat().st_size > MAX_FILE_SIZE:
        print(f"  Skipping {image_path.name}: exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit", file=sys.stderr)
        return None

    for attempt in range(max_retries + 1):
        try:
            with open(image_path, "rb") as f:
                image_b64 = base64.b64encode(f.read()).decode()

            req = urllib.request.Request(
                f"{ollama_url}/api/generate",
                data=json.dumps({
                    "model": model,
                    "prompt": SHAREGPT_SCENE_PROMPT,
                    "images": [image_b64],
                    "stream": False,
                    "options": {"temperature": 0.5, "num_predict": 256},
                }).encode(),
                headers={"Content-Type": "application/json"},
            )
            resp = urllib.request.urlopen(req, timeout=120)
            data = json.loads(resp.read())
            return data.get("response", "").strip()

        except (urllib.error.URLError, TimeoutError):
            if attempt < max_retries:
                time.sleep(2 ** attempt)
            else:
                return None
        except Exception:
            return None


def generate_training_pairs(
    media_files: list[Path],
    model: str,
    ollama_url: str,
    limit: int = 0,
    dry_run: bool = False,
    output_format: str = "jsonl",
) -> list[dict]:
    """Generate training pairs from media files."""
    pairs = []
    files = media_files[:limit] if limit > 0 else media_files

    print(f"Processing {len(files)} files with model {model}...", file=sys.stderr)

    for i, media_path in enumerate(files):
        print(f"  [{i + 1}/{len(files)}] {media_path.name}...", file=sys.stderr, end=" ", flush=True)

        if dry_run:
            print("(dry run)", file=sys.stderr)
            pairs.append({"source": str(media_path), "status": "dry-run"})
            continue

        is_video = media_path.suffix.lower() in VIDEO_EXTS
        work_path = media_path

        if is_video:
            frame_path = media_path.with_suffix(".frame.jpg")
            if extract_video_frame(media_path, frame_path):
                work_path = frame_path
            else:
                print("SKIP (frame extraction failed)", file=sys.stderr)
                continue

        try:
            if output_format == "sharegpt":
                # ShareGPT format for training pipeline
                description = describe_image_sharegpt(work_path, model, ollama_url)
                if description:
                    pair = {
                        "conversations": [
                            {"from": "human", "value": f"<image>\n{SHAREGPT_SCENE_PROMPT}"},
                            {"from": "gpt", "value": description},
                        ],
                        "source": str(media_path),
                        "media_type": "video" if is_video else "image",
                        "model": model,
                        "generated_at": datetime.now(timezone.utc).isoformat(),
                    }
                    pairs.append(pair)
                    print("OK", file=sys.stderr)
                else:
                    print("FAIL", file=sys.stderr)
            else:
                # Structured JSONL format
                description = describe_image(work_path, model, ollama_url)
                if description:
                    pair = {
                        "source": str(media_path),
                        "media_type": "video" if is_video else "image",
                        "description": description,
                        "model": model,
                        "generated_at": datetime.now(timezone.utc).isoformat(),
                    }
                    pairs.append(pair)
                    print("OK", file=sys.stderr)
                else:
                    print("FAIL", file=sys.stderr)
        finally:
            # FIX #6: Cleanup temp frame in try/finally — survives crashes
            if is_video and work_path != media_path:
                try:
                    work_path.unlink()
                except Exception:
                    pass

        # Small delay between files (reduced from 0.5s — Ollama is local)
        time.sleep(0.1)

    return pairs


def main():
    parser = argparse.ArgumentParser(
        description="Auto-generate scene descriptions from media assets using vision AI"
    )
    parser.add_argument("--input", "-i", default="", help="Input directory with media files")
    parser.add_argument("--output", "-o", default="training-data/scene-descriptions-auto.jsonl")
    parser.add_argument("--model", "-m", default=None, help="Ollama model name (auto-detects best available if omitted)")
    parser.add_argument("--ollama-url", default="http://localhost:11434")
    parser.add_argument("--limit", "-l", type=int, default=0, help="Max files to process (0=all)")
    parser.add_argument("--dry-run", action="store_true", help="List files without generating")
    parser.add_argument("--check-model", action="store_true", help="Check model availability and exit")
    parser.add_argument("--format", choices=["jsonl", "sharegpt"], default="jsonl",
                        help="Output format: jsonl (structured) or sharegpt (training pipeline)")
    args = parser.parse_args()

    # Model detection
    if args.check_model:
        if args.model:
            available = check_model_available(args.model, args.ollama_url)
            print(f"Model '{args.model}': {'✅ available' if available else '❌ not found'}")
        else:
            model = auto_detect_model(args.ollama_url)
            if model:
                print(f"✅ Best available: {model}")
            else:
                print("❌ No vision models found in Ollama — install one with: ollama pull gemma4:latest")
        sys.exit(0)

    # Auto-detect model if not specified
    model = args.model
    if not model:
        model = auto_detect_model(args.ollama_url)
        if not model:
            # Fall back to best default even if not installed — let Ollama handle the error
            model = "gemma4:latest"
            print(f"Warning: No vision models detected. Falling back to {model}", file=sys.stderr)

    # Validate input
    if not args.input:
        print("Error: --input is required (unless using --check-model)", file=sys.stderr)
        sys.exit(1)

    # Scan and process
    media_files = scan_media(args.input)
    print(f"Found {len(media_files)} media files", file=sys.stderr)

    if not media_files:
        print("No media files found.", file=sys.stderr)
        sys.exit(1)

    pairs = generate_training_pairs(
        media_files, model, args.ollama_url,
        args.limit, args.dry_run, args.format
    )

    # Write output
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        for pair in pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")

    print(f"\nWrote {len(pairs)} pairs to {output_path}", file=sys.stderr)

    # Summary
    success = len([p for p in pairs if "description" in p or "conversations" in p])
    failed = len(pairs) - success
    if failed > 0:
        print(f"  ⚠️ {failed} files failed", file=sys.stderr)


if __name__ == "__main__":
    main()
333
tests/test_scene_descriptions.py
Normal file
@@ -0,0 +1,333 @@
#!/usr/bin/env python3
"""
Tests for generate_scene_descriptions.py

Tests the scene description generation pipeline including:
- Media file scanning
- Model detection
- JSON parsing from vision responses
- Output format validation

Ref: timmy-config#689
"""

import json
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

# Add scripts to path for import
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "scripts"))

from generate_scene_descriptions import (
    IMAGE_EXTS,
    VIDEO_EXTS,
    ALL_EXTS,
    VISION_MODELS,
    auto_detect_model,
    check_model_available,
    scan_media,
    extract_video_frame,
)


class TestMediaScanning(unittest.TestCase):
    """Test media file scanning."""

    def test_scan_empty_directory(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            result = scan_media(tmpdir)
            self.assertEqual(result, [])

    def test_scan_nonexistent_directory(self):
        result = scan_media("/nonexistent/path/that/does/not/exist")
        self.assertEqual(result, [])

    def test_scan_with_images(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create test files
            for ext in [".jpg", ".png", ".webp"]:
                (Path(tmpdir) / f"test{ext}").touch()

            result = scan_media(tmpdir)
            self.assertEqual(len(result), 3)

    def test_scan_recursive(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            subdir = Path(tmpdir) / "sub" / "dir"
            subdir.mkdir(parents=True)
            (subdir / "deep.jpg").touch()
            (Path(tmpdir) / "top.png").touch()

            result = scan_media(tmpdir)
            self.assertEqual(len(result), 2)

    def test_scan_ignores_unsupported(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            (Path(tmpdir) / "image.jpg").touch()
            (Path(tmpdir) / "document.pdf").touch()
            (Path(tmpdir) / "script.py").touch()

            result = scan_media(tmpdir)
            self.assertEqual(len(result), 1)

    def test_scan_sorted_output(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            for name in ["z.jpg", "a.png", "m.webp"]:
                (Path(tmpdir) / name).touch()

            result = scan_media(tmpdir)
            names = [p.name for p in result]
            self.assertEqual(names, sorted(names))


class TestModelDetection(unittest.TestCase):
    """Test model availability detection."""

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_check_model_available(self, mock_urlopen):
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps({
            "models": [{"name": "gemma4:latest"}]
        }).encode()
        mock_urlopen.return_value = mock_resp

        result = check_model_available("gemma4:latest")
        self.assertTrue(result)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_check_model_not_available(self, mock_urlopen):
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps({
            "models": [{"name": "llama2:7b"}]
        }).encode()
        mock_urlopen.return_value = mock_resp

        result = check_model_available("gemma4:latest")
        self.assertFalse(result)

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_prefers_gemma4(self, mock_check):
        def side_effect(model, url):
            return model == "gemma4:latest"
        mock_check.side_effect = side_effect

        result = auto_detect_model()
        self.assertEqual(result, "gemma4:latest")

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_falls_back(self, mock_check):
        def side_effect(model, url):
            return model == "llava:latest"
        mock_check.side_effect = side_effect

        result = auto_detect_model()
        self.assertEqual(result, "llava:latest")

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_returns_none_when_no_models(self, mock_check):
        mock_check.return_value = False
        result = auto_detect_model()
        self.assertIsNone(result)


class TestConstants(unittest.TestCase):
    """Test constant definitions."""

    def test_image_extensions(self):
        self.assertIn(".jpg", IMAGE_EXTS)
        self.assertIn(".png", IMAGE_EXTS)
        self.assertIn(".webp", IMAGE_EXTS)

    def test_video_extensions(self):
        self.assertIn(".mp4", VIDEO_EXTS)
        self.assertIn(".webm", VIDEO_EXTS)

    def test_all_extensions_union(self):
        self.assertEqual(ALL_EXTS, IMAGE_EXTS | VIDEO_EXTS)

    def test_vision_models_ordered(self):
        self.assertEqual(VISION_MODELS[0], "gemma4:latest")
        self.assertIn("llava:latest", VISION_MODELS)


class TestVideoFrameExtraction(unittest.TestCase):
    """Test video frame extraction."""

    def test_extract_nonexistent_video(self):
        result = extract_video_frame(Path("/nonexistent.mp4"), Path("/tmp/frame.jpg"))
        self.assertFalse(result)


class TestDescribeImage(unittest.TestCase):
    """Test describe_image() with mocked Ollama responses."""

    def test_skips_oversized_file(self):
        """Files exceeding MAX_FILE_SIZE should be skipped without API call."""
        import generate_scene_descriptions
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\x00" * (51 * 1024 * 1024))
            f.flush()
            result = generate_scene_descriptions.describe_image(Path(f.name))
        Path(f.name).unlink()
        self.assertIsNone(result)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_parses_valid_json_response(self, mock_urlopen):
        """Valid JSON response should be parsed and returned."""
        import generate_scene_descriptions
        resp_data = {
            "response": '{"mood": "calm", "colors": ["blue", "white"], "composition": "wide-shot", "camera": "static", "lighting": "natural", "description": "A serene ocean scene."}'
        }
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps(resp_data).encode()
        mock_urlopen.return_value = mock_resp

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\xff\xd8\xff\xe0" + b"\x00" * 1000)
            f.flush()
            result = generate_scene_descriptions.describe_image(Path(f.name))
        Path(f.name).unlink()

        self.assertIsNotNone(result)
        self.assertEqual(result["mood"], "calm")
        self.assertIn("lighting", result)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_repair_truncated_json(self, mock_urlopen):
        """Truncated JSON should be repaired with regex extraction."""
        import generate_scene_descriptions
        resp_data = {
            "response": '{"mood": "dark", "colors": ["red"], "composition": "close-up", "camera": "handheld", "lighting": "dramatic", "description": "A shadowy figure in a dimly lit alley'
        }
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps(resp_data).encode()
        mock_urlopen.return_value = mock_resp

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\xff\xd8\xff\xe0" + b"\x00" * 1000)
            f.flush()
            result = generate_scene_descriptions.describe_image(Path(f.name))
        Path(f.name).unlink()

        self.assertIsNotNone(result)
        self.assertEqual(result["mood"], "dark")
        self.assertEqual(result["lighting"], "dramatic")

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_fallback_on_invalid_json(self, mock_urlopen):
        """Completely invalid JSON response should still return a fallback."""
        import generate_scene_descriptions
        resp_data = {"response": "This is just plain text describing a beautiful sunset over mountains."}
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps(resp_data).encode()
        mock_urlopen.return_value = mock_resp

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\xff\xd8\xff\xe0" + b"\x00" * 1000)
            f.flush()
            result = generate_scene_descriptions.describe_image(Path(f.name))
        Path(f.name).unlink()

        self.assertIsNotNone(result)
        self.assertIn("description", result)
        self.assertIn("lighting", result)


class TestDescribeImageSharegpt(unittest.TestCase):
    """Test describe_image_sharegpt() with mocked Ollama responses."""

    def test_skips_oversized_file(self):
        """Files exceeding MAX_FILE_SIZE should be skipped."""
        import generate_scene_descriptions
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\x00" * (51 * 1024 * 1024))
            f.flush()
            result = generate_scene_descriptions.describe_image_sharegpt(Path(f.name))
        Path(f.name).unlink()
        self.assertIsNone(result)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_returns_natural_language(self, mock_urlopen):
        """Should return the raw response text."""
        import generate_scene_descriptions
        resp_data = {"response": "A warm sunset over rolling hills with golden light."}
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps(resp_data).encode()
        mock_urlopen.return_value = mock_resp

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\xff\xd8\xff\xe0" + b"\x00" * 1000)
            f.flush()
            result = generate_scene_descriptions.describe_image_sharegpt(Path(f.name))
        Path(f.name).unlink()

        self.assertIsNotNone(result)
        self.assertIn("sunset", result)


class TestGenerateTrainingPairs(unittest.TestCase):
    """Test generate_training_pairs() orchestration."""

    @patch('generate_scene_descriptions.describe_image')
    def test_jsonl_output_format(self, mock_describe):
        """JSONL format should produce structured description objects."""
        import generate_scene_descriptions
        mock_describe.return_value = {"mood": "calm", "description": "Test"}

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\x00" * 1000)
            f.flush()
            pairs = generate_scene_descriptions.generate_training_pairs(
                [Path(f.name)], "test-model", "http://localhost:11434",
                output_format="jsonl"
            )
        Path(f.name).unlink()

        self.assertEqual(len(pairs), 1)
        self.assertIn("description", pairs[0])
        self.assertIn("generated_at", pairs[0])

    @patch('generate_scene_descriptions.describe_image_sharegpt')
    def test_sharegpt_output_format(self, mock_describe):
        """ShareGPT format should produce conversation objects."""
        import generate_scene_descriptions
        mock_describe.return_value = "A description of the scene."

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\x00" * 1000)
            f.flush()
            pairs = generate_scene_descriptions.generate_training_pairs(
                [Path(f.name)], "test-model", "http://localhost:11434",
                output_format="sharegpt"
            )
        Path(f.name).unlink()

        self.assertEqual(len(pairs), 1)
        self.assertIn("conversations", pairs[0])
        self.assertEqual(len(pairs[0]["conversations"]), 2)

    @patch('generate_scene_descriptions.describe_image')
    def test_dry_run_skips_api_calls(self, mock_describe):
        """Dry run should not call describe_image."""
        import generate_scene_descriptions
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(b"\x00" * 1000)
            f.flush()
            pairs = generate_scene_descriptions.generate_training_pairs(
                [Path(f.name)], "test-model", "http://localhost:11434",
                dry_run=True
            )
        Path(f.name).unlink()

        mock_describe.assert_not_called()
        self.assertEqual(len(pairs), 1)
        self.assertEqual(pairs[0]["status"], "dry-run")


if __name__ == "__main__":
    unittest.main()
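The suite uses only the standard library's unittest runner. A minimal sketch of invoking it programmatically from the repository root (equivalent to `python -m unittest tests.test_scene_descriptions`, assuming the layout above):

```python
# Minimal sketch: discover and run the new tests from the repository root.
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_scene_descriptions.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```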