Files
timmy-config/scripts/generate_scenes_from_media.py

287 lines
11 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
generate_scenes_from_media.py Auto-generate scene descriptions from image/video assets.
Scans a directory for images/videos, generates scene descriptions using
a vision model, and outputs as training pairs in JSONL format.
Usage:
python3 scripts/generate_scenes_from_media.py --assets ~/assets/ --output training-data/media-scenes.jsonl
python3 scripts/generate_scenes_from_media.py --assets ~/assets/ --model llava --dry-run
python3 scripts/generate_scenes_from_media.py --assets ~/assets/ --max 10 --json
"""
import argparse
import base64
import hashlib
import json
import os
import re
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple
# Supported media formats. Extensions are stored lowercase; scanning is
# expected to compare against the lowercased file suffix.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff"}  # still images
VIDEO_EXTENSIONS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".flv"}  # video containers
ALL_EXTENSIONS = IMAGE_EXTENSIONS | VIDEO_EXTENSIONS  # union used by the scanner
def find_media_files(assets_dir: str, max_files: int = 0) -> List[Path]:
    """Recursively scan *assets_dir* for supported media files.

    The extension match is fully case-insensitive: the previous
    implementation globbed only all-lower (``*.jpg``) and all-upper
    (``*.JPG``) variants, silently skipping mixed-case names such as
    ``photo.Jpg``, and walked the tree once per extension.

    Args:
        assets_dir: Directory to scan (subdirectories included).
        max_files: If > 0, truncate the sorted result to this many files.

    Returns:
        Sorted list of matching paths; empty list (with a message on
        stderr) when the directory does not exist.
    """
    assets_path = Path(assets_dir)
    if not assets_path.exists():
        print(f"ERROR: Directory not found: {assets_dir}", file=sys.stderr)
        return []
    # Single walk over the tree; filtering on the lowercased suffix picks
    # up every case mix of a supported extension. rglob("*") yields each
    # path once, so no explicit deduplication is needed.
    media_files = sorted(
        p for p in assets_path.rglob("*")
        if p.is_file() and p.suffix.lower() in ALL_EXTENSIONS
    )
    if max_files > 0:
        media_files = media_files[:max_files]
    return media_files
def file_hash(filepath: Path) -> str:
    """Return a short, stable identifier for *filepath*.

    Note: the digest is computed over the path string, not the file
    contents, so identical files at different paths hash differently.
    """
    digest = hashlib.sha256(str(filepath).encode())
    return digest.hexdigest()[:16]
def generate_description_prompt(filepath: Path) -> str:
    """Build the vision-model prompt matching *filepath*'s media kind.

    Images and videos share one template; only the subject noun and the
    camera aspect ("angle" vs "movement") differ.
    """
    is_image = filepath.suffix.lower() in IMAGE_EXTENSIONS
    subject = "image" if is_image else "video frame"
    camera_aspect = "angle" if is_image else "movement"
    return (
        f"Describe this {subject} as a visual scene for a training dataset. "
        f"Include: mood, dominant colors (2-3), composition type, camera {camera_aspect}, "
        "and a vivid 1-2 sentence description. Format as JSON with keys: "
        "mood, colors, composition, camera, description."
    )
def _curl_json(url: str, payload: dict, headers: Optional[List[str]] = None) -> Optional[dict]:
    """POST *payload* as JSON to *url* via curl and return the parsed reply.

    Returns None when curl exits non-zero. Timeout and JSON-decode errors
    propagate to the caller, which handles them uniformly.
    """
    cmd = ["curl", "-s", url]
    for header in headers or []:
        cmd.extend(["-H", header])
    cmd.extend(["-d", json.dumps(payload)])
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
    if result.returncode != 0:
        return None
    return json.loads(result.stdout)


def _image_b64(filepath: Path) -> str:
    """Read *filepath* and return its contents as a base64 ASCII string."""
    with open(filepath, "rb") as f:
        return base64.b64encode(f.read()).decode()


def _image_media_type(filepath: Path) -> str:
    """Best-effort MIME type from the file extension (JPEG, else PNG)."""
    return "image/jpeg" if filepath.suffix.lower() in {".jpg", ".jpeg"} else "image/png"


def call_vision_model(filepath: Path, model: str = "llava") -> Optional[dict]:
    """
    Call a vision model to generate a scene description for *filepath*.

    Supports:
    - llava (local via Ollama)
    - gpt-4-vision (OpenAI API, requires OPENAI_API_KEY)
    - claude-vision (Anthropic API, requires ANTHROPIC_API_KEY)

    Returns:
        Parsed description dict, or None on any failure (missing API key,
        HTTP/curl error, timeout, unreadable media file, unknown model).
    """
    prompt = generate_description_prompt(filepath)
    try:
        if model.startswith("llava") or model == "ollama":
            # BUG FIX: Ollama's /api/generate expects base64-encoded image
            # bytes in "images" — the previous version passed a filesystem
            # path, which the server cannot use.
            response = _curl_json(
                "http://localhost:11434/api/generate",
                {
                    "model": "llava",
                    "prompt": prompt,
                    "images": [_image_b64(filepath)],
                    "stream": False,
                },
            )
            if response is not None:
                return parse_description(response.get("response", ""))
        elif model.startswith("gpt-4"):
            api_key = os.environ.get("OPENAI_API_KEY")
            if not api_key:
                print("ERROR: OPENAI_API_KEY not set", file=sys.stderr)
                return None
            # Previously hardcoded image/jpeg even for PNG uploads.
            media_type = _image_media_type(filepath)
            image_data = _image_b64(filepath)
            response = _curl_json(
                "https://api.openai.com/v1/chat/completions",
                {
                    "model": "gpt-4-vision-preview",
                    "messages": [{
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url",
                             "image_url": {"url": f"data:{media_type};base64,{image_data}"}},
                        ],
                    }],
                    "max_tokens": 500,
                },
                headers=[
                    f"Authorization: Bearer {api_key}",
                    "Content-Type: application/json",
                ],
            )
            if response is not None:
                return parse_description(response["choices"][0]["message"]["content"])
        elif model.startswith("claude"):
            api_key = os.environ.get("ANTHROPIC_API_KEY")
            if not api_key:
                print("ERROR: ANTHROPIC_API_KEY not set", file=sys.stderr)
                return None
            response = _curl_json(
                "https://api.anthropic.com/v1/messages",
                {
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 500,
                    "messages": [{
                        "role": "user",
                        "content": [
                            {"type": "image",
                             "source": {"type": "base64",
                                        "media_type": _image_media_type(filepath),
                                        "data": _image_b64(filepath)}},
                            {"type": "text", "text": prompt},
                        ],
                    }],
                },
                headers=[
                    f"x-api-key: {api_key}",
                    "anthropic-version: 2023-06-01",
                    "Content-Type: application/json",
                ],
            )
            if response is not None:
                return parse_description(response["content"][0]["text"])
    except (subprocess.TimeoutExpired, json.JSONDecodeError, KeyError, OSError) as e:
        # OSError covers an unreadable/missing media file, which the
        # original version let propagate out of the function.
        print(f"ERROR calling vision model: {e}", file=sys.stderr)
    return None
def parse_description(text: str) -> dict:
    """Parse a model response into a structured scene-description dict.

    First tries to extract and decode a JSON object from *text*; on
    failure, falls back to regex-scraping the mood and colors and using
    the truncated raw text as the description.

    Returns:
        The decoded JSON object when one is found, otherwise a dict with
        keys mood, colors, composition, camera, description.
    """
    # BUG FIX: the original pattern r'\{[^}]+\}' stops at the first '}',
    # so JSON containing nested objects never parsed. Try the greedy span
    # first (handles nesting), then fall back to the narrow original.
    for pattern in (r'\{.*\}', r'\{[^}]+\}'):
        json_match = re.search(pattern, text, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                pass
    # Fallback: scrape fields manually from free-form text.
    desc = {
        "mood": "unknown",
        "colors": [],
        "composition": "unknown",
        "camera": "unknown",
        "description": text[:500],
    }
    mood_match = re.search(r'mood["\s:]+(\w+)', text, re.IGNORECASE)
    if mood_match:
        desc["mood"] = mood_match.group(1).lower()
    color_match = re.search(r'colors?["\s:]+\[([^\]]+)\]', text, re.IGNORECASE)
    if color_match:
        desc["colors"] = [c.strip().strip('"').strip("'") for c in color_match.group(1).split(",")]
    return desc
def generate_training_pair(filepath: Path, description: dict, model: str) -> dict:
    """Assemble one JSONL training record from a media file and its description.

    Every scene field falls back to a neutral default when absent from
    *description*; the top-level "response" mirrors the scene description.
    """
    scene = {
        "mood": description.get("mood", "unknown"),
        "colors": description.get("colors", []),
        "composition": description.get("composition", "unknown"),
        "camera": description.get("camera", "unknown"),
        "description": description.get("description", ""),
    }
    is_image = filepath.suffix.lower() in IMAGE_EXTENSIONS
    record = {
        "source_file": str(filepath),
        "source_hash": file_hash(filepath),
        "source_type": "media_asset",
        "media_type": "image" if is_image else "video",
        "model": model,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "source_session_id": f"media-gen-{int(time.time())}",
        "prompt": f"Describe the visual scene in {filepath.name}",
        "response": scene["description"],
        "scene": scene,
    }
    return record
def main():
    """CLI entry point: scan assets, describe each file, emit JSONL pairs."""
    parser = argparse.ArgumentParser(description="Generate scene descriptions from media")
    parser.add_argument("--assets", required=True, help="Assets directory to scan")
    parser.add_argument("--output", help="Output JSONL file path")
    parser.add_argument("--model", default="llava", help="Vision model (llava/gpt-4/claude)")
    parser.add_argument("--max", type=int, default=0, help="Max files to process (0=all)")
    parser.add_argument("--dry-run", action="store_true", help="Don't call vision model")
    parser.add_argument("--json", action="store_true", help="JSON output")
    args = parser.parse_args()

    media_files = find_media_files(args.assets, args.max)
    if not media_files:
        print("No media files found.", file=sys.stderr)
        sys.exit(1)
    print(f"Found {len(media_files)} media files in {args.assets}")

    if args.dry_run:
        print("\nDry run — files to process:")
        for f in media_files[:20]:
            print(f" {f.relative_to(args.assets)}")
        if len(media_files) > 20:
            print(f" ... and {len(media_files) - 20} more")
        sys.exit(0)

    pairs = []
    errors = 0
    for i, filepath in enumerate(media_files, 1):
        print(f"[{i}/{len(media_files)}] Processing {filepath.name}...", end=" ", flush=True)
        description = call_vision_model(filepath, args.model)
        if description:
            pair = generate_training_pair(filepath, description, args.model)
            pairs.append(pair)
            print(f"OK (mood: {pair['scene']['mood']})")
        else:
            errors += 1
            print("ERROR")

    output_path = args.output or "training-data/media-scene-descriptions.jsonl"
    if args.json:
        print(json.dumps({"pairs": pairs, "total": len(pairs), "errors": errors}, indent=2))
    else:
        # BUG FIX: create the output directory if missing — the default
        # path crashed with FileNotFoundError when training-data/ did not
        # exist. Encoding is pinned to match ensure_ascii=False output.
        out = Path(output_path)
        out.parent.mkdir(parents=True, exist_ok=True)
        with out.open("w", encoding="utf-8") as f:
            for pair in pairs:
                f.write(json.dumps(pair, ensure_ascii=False) + '\n')
        print(f"\nGenerated {len(pairs)} scene descriptions ({errors} errors)")
        print(f"Output: {output_path}")


if __name__ == "__main__":
    main()