Compare commits
13 Commits
feat/543-t
...
fix/689-au
| Author | SHA1 | Date | |
|---|---|---|---|
| 76a886334b | |||
| e1abecbc54 | |||
| b3f5a2f21c | |||
| e176fadef5 | |||
| 7ca2ebe6b5 | |||
| e9d2cb5e56 | |||
| 990676fb02 | |||
| 3ad934febd | |||
| 35a191f7b1 | |||
| e987e1b870 | |||
| 19278513b4 | |||
| b6e3a647b0 | |||
| e14158676d |
271
bin/preflight-provider-check.py
Normal file
271
bin/preflight-provider-check.py
Normal file
@@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pre-Flight Provider Check Script
|
||||
Issue #508: [Robustness] Credential drain detection — provider health checks
|
||||
|
||||
Pre-flight check before session launch: verifies provider credentials and balance.
|
||||
|
||||
Usage:
|
||||
python3 preflight-provider-check.py # Check all providers
|
||||
python3 preflight-provider-check.py --launch # Check and return exit code
|
||||
python3 preflight-provider-check.py --balance # Check OpenRouter balance
|
||||
"""
|
||||
|
||||
import os, sys, json, yaml, urllib.request
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Configuration
# Hermes home directory: taken from the HERMES_HOME environment variable when
# it is set, otherwise ~/.hermes. The .env and config.yaml files are read
# from this directory.
HERMES_HOME = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
# Directory that receives this script's log output; log() creates it on demand.
LOG_DIR = Path.home() / ".local" / "timmy" / "fleet-health"
# File that log() appends to.
LOG_FILE = LOG_DIR / "preflight-check.log"
|
||||
|
||||
def log(msg):
    """Append a UTC-timestamped entry to LOG_FILE and echo it to stdout.

    The echo is suppressed when ``--quiet`` appears anywhere in ``sys.argv``.
    LOG_DIR is created (including parents) before the file is opened.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    entry = "".join(("[", stamp, "] ", msg))

    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with open(LOG_FILE, "a") as handle:
        handle.write(entry + "\n")

    if "--quiet" in sys.argv:
        return
    print(entry)
|
||||
|
||||
def get_provider_api_key(provider):
    """Return the API key configured for *provider*, or None when there is none.

    The variable name is ``<PROVIDER>_API_KEY`` with the provider upper-cased
    and hyphens turned into underscores (``kimi-coding`` ->
    ``KIMI_CODING_API_KEY``), because a hyphen is not usable in a shell or
    .env variable name. The literal upper-cased spelling is still accepted so
    existing setups keep working.

    Lookup order: ``HERMES_HOME/.env`` first, then the process environment.
    In the .env file an optional ``export `` prefix and surrounding quotes are
    tolerated; an empty value is skipped so it cannot mask a real key defined
    further down or in the environment.

    Args:
        provider: Provider name such as ``"openrouter"``.

    Returns:
        The key as a non-empty string, or None.
    """
    upper = provider.upper()
    # dict.fromkeys drops the duplicate when the name contains no hyphen.
    names = list(dict.fromkeys(
        [upper.replace("-", "_") + "_API_KEY", upper + "_API_KEY"]
    ))

    env_file = HERMES_HOME / ".env"
    if env_file.exists():
        with open(env_file) as f:
            for raw in f:
                line = raw.strip()
                if line.startswith("export "):
                    line = line[len("export "):].lstrip()
                name, sep, value = line.partition("=")
                if not sep or name.strip() not in names:
                    continue
                value = value.strip().strip("'\"")
                if value:
                    return value

    for name in names:
        value = os.environ.get(name)
        if value:
            return value
    return None
|
||||
|
||||
def check_openrouter_balance(api_key):
    """Query OpenRouter's ``/api/v1/auth/key`` endpoint and report credit status.

    Args:
        api_key: OpenRouter API key. A falsy value returns immediately
            without making a request.

    Returns:
        A ``(healthy, message, balance)`` tuple. ``balance`` is
        ``limit - usage`` when the response carries a non-zero limit, None
        when no limit is reported, and 0 on every failure path.
    """
    # Only urllib.request is imported at module level; make the dependency on
    # urllib.error explicit instead of relying on it being loaded as a side
    # effect.
    import urllib.error

    if not api_key:
        return False, "No API key", 0

    try:
        req = urllib.request.Request(
            "https://openrouter.ai/api/v1/auth/key",
            headers={"Authorization": "Bearer " + api_key},
        )
        # The context manager closes the HTTP response even if parsing fails.
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())

        # "data", "limit" and "usage" can be present but null; treat null the
        # same as absent so a valid key is not reported as a failure.
        key_info = data.get("data") or {}
        limit = key_info.get("limit") or 0
        usage = key_info.get("usage") or 0

        if not limit:
            return True, "Unlimited or unknown balance", None

        remaining = limit - usage
        if remaining <= 0:
            return False, "No credits remaining", 0
        return True, "Credits available", remaining

    except urllib.error.HTTPError as e:
        if e.code == 401:
            return False, "Invalid API key", 0
        return False, "HTTP " + str(e.code), 0
    except Exception as e:
        # Best-effort probe: any other failure is reported, never raised.
        return False, str(e)[:100], 0
|
||||
|
||||
def check_nous_key(api_key):
    """Validate a Nous API key by listing models on the inference endpoint.

    Args:
        api_key: Nous API key. A falsy value returns immediately without
            making a request.

    Returns:
        A ``(healthy, message)`` tuple; ``healthy`` is True only for an
        HTTP 200 response.
    """
    # Explicit import: the module only imports urllib.request at top level.
    import urllib.error

    if not api_key:
        return False, "No API key"

    try:
        req = urllib.request.Request(
            "https://inference.nousresearch.com/v1/models",
            headers={"Authorization": "Bearer " + api_key},
        )
        # Close the response as soon as the status has been read.
        with urllib.request.urlopen(req, timeout=10) as resp:
            status = resp.status
    except urllib.error.HTTPError as e:
        if e.code == 401:
            return False, "Invalid API key"
        if e.code == 403:
            return False, "Forbidden"
        return False, "HTTP " + str(e.code)
    except Exception as e:
        # Best-effort probe: report the failure instead of raising.
        return False, str(e)[:100]

    if status == 200:
        return True, "Valid key"
    return False, "HTTP " + str(status)
|
||||
|
||||
def check_anthropic_key(api_key):
    """Validate an Anthropic API key by listing models.

    Args:
        api_key: Anthropic API key. A falsy value returns immediately without
            making a request.

    Returns:
        A ``(healthy, message)`` tuple; ``healthy`` is True only for an
        HTTP 200 response.
    """
    # Explicit import: the module only imports urllib.request at top level.
    import urllib.error

    if not api_key:
        return False, "No API key"

    try:
        req = urllib.request.Request(
            "https://api.anthropic.com/v1/models",
            headers={
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
            },
        )
        # Close the response as soon as the status has been read.
        with urllib.request.urlopen(req, timeout=10) as resp:
            status = resp.status
    except urllib.error.HTTPError as e:
        if e.code == 401:
            return False, "Invalid API key"
        if e.code == 403:
            return False, "Forbidden"
        return False, "HTTP " + str(e.code)
    except Exception as e:
        # Best-effort probe: report the failure instead of raising.
        return False, str(e)[:100]

    if status == 200:
        return True, "Valid key"
    return False, "HTTP " + str(status)
|
||||
|
||||
def check_ollama():
    """Check whether a local Ollama server answers on localhost:11434.

    Returns:
        A ``(healthy, message)`` tuple. On success the message states how
        many models the ``/api/tags`` response lists; on failure it carries
        the HTTP status or the first 100 characters of the error.
    """
    try:
        req = urllib.request.Request("http://localhost:11434/api/tags")
        # The context manager closes the response on every path.
        with urllib.request.urlopen(req, timeout=5) as resp:
            if resp.status != 200:
                return False, "HTTP " + str(resp.status)
            data = json.loads(resp.read())

        # A null "models" field counts as an empty list, not as a failure.
        models = data.get("models") or []
        return True, str(len(models)) + " models loaded"

    except Exception as e:
        # Best-effort probe: a server that is down is an expected outcome.
        return False, str(e)[:100]
|
||||
|
||||
def get_configured_provider():
    """Return ``model.provider`` from ``HERMES_HOME/config.yaml``.

    Returns:
        The configured provider value, or None when the file is missing,
        unreadable, not valid YAML, empty, or has no mapping under ``model``.
    """
    config_file = HERMES_HOME / "config.yaml"
    if not config_file.exists():
        return None

    try:
        with open(config_file) as f:
            config = yaml.safe_load(f)
    except (OSError, ValueError, yaml.YAMLError):
        # Unreadable or malformed config is treated as "nothing configured";
        # KeyboardInterrupt and SystemExit are deliberately not swallowed.
        return None

    # safe_load returns None for an empty file and may return a scalar/list.
    if not isinstance(config, dict):
        return None

    model_config = config.get("model", {})
    if isinstance(model_config, dict):
        return model_config.get("provider")
    return None
|
||||
|
||||
def run_preflight_check():
    """Probe OpenRouter, Nous, Anthropic and Ollama and log a summary.

    Returns:
        A ``(results, configured)`` tuple. ``results`` maps each provider name
        to a dict with ``healthy`` and ``message`` (plus ``balance`` for
        OpenRouter); ``configured`` is whatever get_configured_provider()
        returned.
    """
    log("=== Pre-Flight Provider Check ===")

    results = {}

    # OpenRouter is the only probe that also reports a balance.
    healthy, message, balance = check_openrouter_balance(
        get_provider_api_key("openrouter")
    )
    results["openrouter"] = {"healthy": healthy, "message": message, "balance": balance}

    # Key-only probes share the same (healthy, message) result shape.
    for name, checker in (("nous", check_nous_key), ("anthropic", check_anthropic_key)):
        healthy, message = checker(get_provider_api_key(name))
        results[name] = {"healthy": healthy, "message": message}

    healthy, message = check_ollama()
    results["ollama"] = {"healthy": healthy, "message": message}

    configured = get_configured_provider()

    healthy_count = len([entry for entry in results.values() if entry["healthy"]])
    log(f"Results: {healthy_count}/{len(results)} providers healthy")

    for provider, result in results.items():
        status = "HEALTHY" if result["healthy"] else "UNHEALTHY"
        shown_balance = result.get("balance")
        extra = ""
        if provider == "openrouter" and shown_balance is not None:
            extra = f" (balance: {shown_balance})"
        log(f" {provider}: {status} - {result['message']}{extra}")

    if configured:
        log(f"Configured provider: {configured}")
        if configured in results and not results[configured]["healthy"]:
            log(f"WARNING: Configured provider {configured} is UNHEALTHY!")

    return results, configured
|
||||
|
||||
def check_launch_readiness():
    """Decide whether sessions may be launched, based on a fresh pre-flight run.

    Launch is blocked when the configured provider was probed and is
    unhealthy, or when no probed provider is healthy at all.

    Returns:
        A ``(ready, reason)`` tuple.
    """
    results, configured = run_preflight_check()

    # A configured provider that was not probed does not block the launch.
    if configured and configured in results and not results[configured]["healthy"]:
        log(f"LAUNCH BLOCKED: Configured provider {configured} is unhealthy")
        return False, configured + " is unhealthy"

    healthy_providers = [name for name in results if results[name]["healthy"]]
    if healthy_providers:
        log(f"LAUNCH READY: {len(healthy_providers)} healthy providers available")
        return True, "Ready"

    log("LAUNCH BLOCKED: No healthy providers available")
    return False, "No healthy providers"
|
||||
|
||||
def show_balance():
    """Print the OpenRouter balance, or the probe message when none is known."""
    api_key = get_provider_api_key("openrouter")
    if not api_key:
        print("No OpenRouter API key found")
        return

    ok, msg, balance = check_openrouter_balance(api_key)

    # Only a successful probe with a concrete number prints a balance line;
    # every other outcome prints the probe's message.
    if ok and balance is not None:
        print("OpenRouter balance: " + str(balance) + " credits")
        return
    print("OpenRouter: " + msg)
|
||||
|
||||
def main():
    """Dispatch on command-line flags.

    ``--balance`` prints the OpenRouter balance, ``--launch`` prints
    READY/BLOCKED and exits with status 0/1, and no flag runs the full
    pre-flight check. ``--balance`` wins when both flags are present.
    """
    args = sys.argv

    if "--balance" in args:
        show_balance()
        return

    if "--launch" not in args:
        run_preflight_check()
        return

    ready, message = check_launch_readiness()
    print("READY" if ready else "BLOCKED: " + message)
    sys.exit(0 if ready else 1)


if __name__ == "__main__":
    main()
|
||||
411
bin/provider-health-monitor.py
Normal file
411
bin/provider-health-monitor.py
Normal file
@@ -0,0 +1,411 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Provider Health Monitor Script
|
||||
Issue #509: [Robustness] Provider-aware profile config — auto-switch on failure
|
||||
|
||||
Monitors provider health and automatically switches profiles to working providers.
|
||||
|
||||
Usage:
|
||||
python3 provider-health-monitor.py # Run once
|
||||
python3 provider-health-monitor.py --daemon # Run continuously
|
||||
python3 provider-health-monitor.py --status # Show provider health
|
||||
"""
|
||||
|
||||
import os, sys, json, yaml, urllib.request, time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Configuration
# Hermes home directory: taken from the HERMES_HOME environment variable when
# it is set, otherwise ~/.hermes. The global .env and config.yaml are read
# from this directory.
HERMES_HOME = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
# One sub-directory per profile, each with its own config.yaml.
PROFILES_DIR = HERMES_HOME / "profiles"
# Directory for this script's log and state files; created on demand.
LOG_DIR = Path.home() / ".local" / "timmy" / "fleet-health"
# JSON state file read by load_state() and written by save_state().
STATE_FILE = LOG_DIR / "tmux-state.json"
# File that log() appends to.
LOG_FILE = LOG_DIR / "provider-health.log"
|
||||
|
||||
# Provider test endpoints
def _bearer_headers(api_key):
    """Build an ``Authorization: Bearer`` header from the key."""
    return {"Authorization": "Bearer " + api_key}


def _anthropic_headers(api_key):
    """Build the x-api-key / anthropic-version header pair."""
    return {"x-api-key": api_key, "anthropic-version": "2023-06-01"}


def _kimi_headers(api_key):
    """Build the x-api-key / x-api-provider header pair."""
    return {"x-api-key": api_key, "x-api-provider": "kimi-coding"}


def _no_headers(api_key):
    """Return no headers; the key is ignored."""
    return {}


# Maps provider name -> probe description: URL, HTTP method, a callable that
# turns an API key into request headers, and a timeout in seconds.
PROVIDER_TESTS = {
    "openrouter": {
        "url": "https://openrouter.ai/api/v1/models",
        "method": "GET",
        "headers": _bearer_headers,
        "timeout": 10,
    },
    "anthropic": {
        "url": "https://api.anthropic.com/v1/models",
        "method": "GET",
        "headers": _anthropic_headers,
        "timeout": 10,
    },
    "nous": {
        "url": "https://inference.nousresearch.com/v1/models",
        "method": "GET",
        "headers": _bearer_headers,
        "timeout": 10,
    },
    "kimi-coding": {
        "url": "https://api.kimi.com/coding/v1/models",
        "method": "GET",
        "headers": _kimi_headers,
        "timeout": 10,
    },
    "ollama": {
        "url": "http://localhost:11434/api/tags",
        "method": "GET",
        "headers": _no_headers,
        "timeout": 5,
    },
}
|
||||
|
||||
def log(msg):
    """Append a UTC-timestamped entry to LOG_FILE and echo it to stdout.

    The echo is suppressed when ``--quiet`` appears anywhere in ``sys.argv``.
    LOG_DIR is created (including parents) before the file is opened.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    entry = "".join(("[", stamp, "] ", msg))

    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with open(LOG_FILE, "a") as handle:
        handle.write(entry + "\n")

    if "--quiet" in sys.argv:
        return
    print(entry)
|
||||
|
||||
def get_provider_api_key(provider):
    """Return the API key configured for *provider*, or None when there is none.

    The variable name is ``<PROVIDER>_API_KEY`` with the provider upper-cased
    and hyphens turned into underscores (``kimi-coding`` ->
    ``KIMI_CODING_API_KEY``), because a hyphen is not usable in a shell or
    .env variable name. The literal upper-cased spelling is still accepted so
    existing setups keep working.

    Lookup order: ``HERMES_HOME/.env`` first, then the process environment.
    In the .env file an optional ``export `` prefix and surrounding quotes are
    tolerated; an empty value is skipped so it cannot mask a real key defined
    further down or in the environment.

    Args:
        provider: Provider name such as ``"openrouter"`` or ``"kimi-coding"``.

    Returns:
        The key as a non-empty string, or None.
    """
    upper = provider.upper()
    # dict.fromkeys drops the duplicate when the name contains no hyphen.
    names = list(dict.fromkeys(
        [upper.replace("-", "_") + "_API_KEY", upper + "_API_KEY"]
    ))

    env_file = HERMES_HOME / ".env"
    if env_file.exists():
        with open(env_file) as f:
            for raw in f:
                line = raw.strip()
                if line.startswith("export "):
                    line = line[len("export "):].lstrip()
                name, sep, value = line.partition("=")
                if not sep or name.strip() not in names:
                    continue
                value = value.strip().strip("'\"")
                if value:
                    return value

    for name in names:
        value = os.environ.get(name)
        if value:
            return value
    return None
|
||||
|
||||
def test_provider(provider, api_key=None):
    """Probe one provider using its entry in PROVIDER_TESTS.

    Args:
        provider: Provider name; must be a key of PROVIDER_TESTS.
        api_key: Key passed to the provider's header builder. None is
            replaced by an empty string.

    Returns:
        A ``(healthy, message)`` tuple. HTTP 200 and HTTP 429 (rate limited,
        but reachable with these credentials) count as healthy; everything
        else, including an unknown provider, does not.
    """
    # Explicit import: the module only imports urllib.request at top level.
    import urllib.error

    config = PROVIDER_TESTS.get(provider)
    if not config:
        return False, "Unknown provider: " + provider

    headers = config["headers"](api_key or "")

    try:
        req = urllib.request.Request(
            config["url"],
            headers=headers,
            method=config["method"],
        )
        # Close the response as soon as the status has been read.
        with urllib.request.urlopen(req, timeout=config["timeout"]) as resp:
            status = resp.status
    except urllib.error.HTTPError as e:
        if e.code == 429:
            return True, "Rate limited but accessible"
        if e.code == 401:
            return False, "Unauthorized (401)"
        if e.code == 403:
            return False, "Forbidden (403)"
        return False, "HTTP " + str(e.code)
    except Exception as e:
        # Best-effort probe: report the failure instead of raising.
        return False, str(e)[:100]

    if status == 200:
        return True, "Healthy"
    return False, "HTTP " + str(status)
|
||||
|
||||
def _providers_from_config(config_file):
    """Return the set of provider names referenced by one config.yaml."""
    try:
        with open(config_file) as f:
            config = yaml.safe_load(f)
    except (OSError, ValueError, yaml.YAMLError):
        # Unreadable or malformed config contributes nothing;
        # KeyboardInterrupt and SystemExit are deliberately not swallowed.
        return set()

    # safe_load returns None for an empty file and may return a scalar/list.
    if not isinstance(config, dict):
        return set()

    found = set()

    # Primary model provider
    model_config = config.get("model", {})
    if isinstance(model_config, dict):
        provider = model_config.get("provider", "")
        if provider and isinstance(provider, str):
            found.add(provider)

    # Auxiliary providers; "auto" is a placeholder, not a provider name.
    auxiliary = config.get("auxiliary")
    if isinstance(auxiliary, dict):
        for aux_config in auxiliary.values():
            if not isinstance(aux_config, dict):
                continue
            provider = aux_config.get("provider", "")
            if provider and isinstance(provider, str) and provider != "auto":
                found.add(provider)

    return found


def get_all_providers():
    """Collect every provider named in the global and per-profile configs.

    ``openrouter``, ``nous`` and ``ollama`` are always included, even when no
    config mentions them.

    Returns:
        A sorted list of unique provider names (sorted so that probing and
        logging order is stable between runs).
    """
    providers = set()

    # Global config
    global_config = HERMES_HOME / "config.yaml"
    if global_config.exists():
        providers |= _providers_from_config(global_config)

    # Profile configs
    if PROFILES_DIR.exists():
        for profile_dir in PROFILES_DIR.iterdir():
            if not profile_dir.is_dir():
                continue
            config_file = profile_dir / "config.yaml"
            if config_file.exists():
                providers |= _providers_from_config(config_file)

    # Add common providers even if not configured
    providers.update(["openrouter", "nous", "ollama"])

    return sorted(providers)
|
||||
|
||||
def build_health_map():
    """Probe every known provider and return the results keyed by name.

    Returns:
        A dict mapping provider name to a dict with ``healthy``, ``message``,
        ``last_test`` (UTC ISO-8601 timestamp) and ``api_key_present``.
    """
    providers = get_all_providers()
    log(f"Testing {len(providers)} providers...")

    health_map = {}
    for name in providers:
        key = get_provider_api_key(name)
        healthy, message = test_provider(name, key)

        health_map[name] = dict(
            healthy=healthy,
            message=message,
            last_test=datetime.now(timezone.utc).isoformat(),
            api_key_present=bool(key),
        )

        status = "HEALTHY" if healthy else "UNHEALTHY"
        log(f" {name}: {status} - {message}")

    return health_map
|
||||
|
||||
def get_fallback_providers(health_map):
    """Return the healthy providers from *health_map*, most preferred first.

    Providers on the fixed preference list come first, in list order; any
    other healthy provider follows in the map's own iteration order.
    """
    preferred = ("nous", "openrouter", "ollama", "anthropic", "kimi-coding")

    ranked = [
        name
        for name in preferred
        if name in health_map and health_map[name]["healthy"]
    ]

    # Healthy providers that are not on the preference list go last.
    already_ranked = set(ranked)
    remainder = [
        name
        for name, info in health_map.items()
        if info["healthy"] and name not in already_ranked
    ]

    return ranked + remainder
|
||||
|
||||
def update_profile_config(profile_name, new_provider):
    """Point a profile's config.yaml at *new_provider*.

    Sets ``model.provider`` and also rewrites every ``auxiliary`` entry whose
    provider equalled the previous model provider.

    Args:
        profile_name: Name of the directory under PROFILES_DIR.
        new_provider: Provider name to write.

    Returns:
        A ``(success, message)`` tuple. Failures are reported through the
        tuple, never raised.
    """
    config_file = PROFILES_DIR / profile_name / "config.yaml"

    if not config_file.exists():
        return False, "Config file not found"

    try:
        with open(config_file) as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty file; start from an empty
        # mapping instead of failing.
        if config is None:
            config = {}
        if not isinstance(config, dict):
            return False, "Config root is not a mapping"

        # Update model provider
        model_config = config.get("model")
        if model_config is None:
            model_config = config["model"] = {}
        elif not isinstance(model_config, dict):
            # Do not overwrite a scalar `model:` value with a mapping.
            return False, "model section is not a mapping"

        old_provider = model_config.get("provider", "unknown")
        model_config["provider"] = new_provider

        # Update auxiliary providers if they were using the old provider
        auxiliary = config.get("auxiliary")
        if isinstance(auxiliary, dict):
            for aux_config in auxiliary.values():
                if isinstance(aux_config, dict) and aux_config.get("provider") == old_provider:
                    aux_config["provider"] = new_provider

        # Serialise BEFORE opening the file for writing: opening with "w"
        # truncates, so a failing dump would otherwise wipe the config.
        serialized = yaml.dump(config, default_flow_style=False)
        with open(config_file, "w") as f:
            f.write(serialized)

        log("Updated " + profile_name + ": " + str(old_provider) + " -> " + str(new_provider))
        return True, "Updated"

    except Exception as e:
        return False, str(e)
|
||||
|
||||
def check_profiles(health_map):
    """Switch every profile whose model provider is unhealthy to a fallback.

    Args:
        health_map: Mapping of provider name to a dict with a ``healthy``
            flag, as produced by build_health_map().

    Returns:
        A list of ``{"profile", "old_provider", "new_provider"}`` dicts, one
        per profile that was rewritten. Always a list: empty when
        PROFILES_DIR is missing, when no provider is healthy, or when nothing
        needed changing.
    """
    if not PROFILES_DIR.exists():
        return []

    fallback_providers = get_fallback_providers(health_map)
    if not fallback_providers:
        log("CRITICAL: No healthy providers available!")
        return []

    updated_profiles = []

    # Sorted so profiles are processed in a stable order between runs.
    for profile_dir in sorted(PROFILES_DIR.iterdir()):
        if not profile_dir.is_dir():
            continue

        profile_name = profile_dir.name
        config_file = profile_dir / "config.yaml"
        if not config_file.exists():
            continue

        try:
            with open(config_file) as f:
                config = yaml.safe_load(f)

            # Empty file (None) or a non-mapping root: nothing to check.
            if not isinstance(config, dict):
                continue

            model_config = config.get("model", {})
            if not isinstance(model_config, dict):
                continue

            current_provider = model_config.get("provider", "")
            if not current_provider:
                continue

            # Check if current provider is healthy
            if current_provider in health_map and health_map[current_provider]["healthy"]:
                continue  # Provider is healthy, no action needed

            # Best fallback = first healthy provider other than the current one.
            best_fallback = next(
                (p for p in fallback_providers if p != current_provider),
                None,
            )
            if not best_fallback:
                log("No fallback for " + profile_name + " (current: " + current_provider + ")")
                continue

            # Update profile
            success, message = update_profile_config(profile_name, best_fallback)
            if success:
                updated_profiles.append({
                    "profile": profile_name,
                    "old_provider": current_provider,
                    "new_provider": best_fallback,
                })
            else:
                # Surface the failure instead of dropping it silently.
                log("Failed to update " + profile_name + ": " + str(message))

        except Exception as e:
            # One broken profile must not stop the others from being checked.
            log("Error processing " + profile_name + ": " + str(e))

    return updated_profiles
|
||||
|
||||
def load_state():
    """Load the persisted state dict from STATE_FILE.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not contain a JSON object at the
        top level (callers assign keys on the result).
    """
    if not STATE_FILE.exists():
        return {}

    try:
        with open(STATE_FILE) as f:
            state = json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are ValueError
        # subclasses; KeyboardInterrupt/SystemExit are not swallowed.
        return {}

    return state if isinstance(state, dict) else {}
|
||||
|
||||
def save_state(state):
    """Write *state* to STATE_FILE as indented JSON.

    Args:
        state: JSON-serialisable dict.

    Raises:
        TypeError / ValueError: when *state* cannot be serialised. The
            existing file is left untouched in that case.
        OSError: when the directory or file cannot be written.
    """
    # Serialise BEFORE opening the file: opening with "w" truncates, so a
    # failing dump would otherwise destroy the previously saved state.
    payload = json.dumps(state, indent=2)

    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with open(STATE_FILE, "w") as f:
        f.write(payload)
|
||||
|
||||
def run_once():
    """Run one full cycle: probe providers, fix profiles, persist, summarise."""
    log("=== Provider Health Check ===")

    state = load_state()

    # Probe first, then let check_profiles act on the fresh results.
    health_map = build_health_map()
    updated_profiles = check_profiles(health_map)

    state.update(
        provider_health=health_map,
        last_provider_check=datetime.now(timezone.utc).isoformat(),
    )
    # The previous update list is kept when this run changed nothing.
    if updated_profiles:
        state["last_profile_updates"] = updated_profiles

    save_state(state)

    healthy_count = len([entry for entry in health_map.values() if entry["healthy"]])
    log(f"Health: {healthy_count}/{len(health_map)} providers healthy")

    if not updated_profiles:
        return
    log(f"Updated {len(updated_profiles)} profiles:")
    for change in updated_profiles:
        log(" " + change["profile"] + ": " + change["old_provider"] + " -> " + change["new_provider"])
|
||||
|
||||
def show_status():
    """Print the provider health table stored by the last run, if any."""
    state = load_state()
    health_map = state.get("provider_health", {})

    if not health_map:
        print("No provider health data available. Run without --status first.")
        return

    last_check = str(state.get("last_provider_check", "unknown"))
    print("Provider Health (last updated: " + last_check + ")")
    print("=" * 80)

    # One row per provider, in alphabetical order.
    for provider, info in sorted(health_map.items()):
        status = "HEALTHY" if info["healthy"] else "UNHEALTHY"
        has_key = "yes" if info.get("api_key_present") else "no"
        columns = (
            provider.ljust(20),
            " ",
            status.ljust(10),
            " API key: ",
            has_key,
            " - ",
            info.get("message", ""),
        )
        print("".join(columns))

    # Show recent updates
    updates = state.get("last_profile_updates", [])
    if not updates:
        return
    print()
    print("Recent Profile Updates:")
    for change in updates:
        print(" " + change["profile"] + ": " + change["old_provider"] + " -> " + change["new_provider"])
|
||||
|
||||
def daemon_mode():
    """Run run_once() in a loop until interrupted with Ctrl-C.

    A successful cycle is followed by a 300 second pause; a cycle that raises
    is logged and retried after 60 seconds.
    """
    check_interval = 300  # seconds between checks (5 minutes)
    retry_delay = 60      # seconds to wait after a failed cycle

    log(f"Starting provider health daemon (check every {check_interval}s)")

    while True:
        try:
            run_once()
            time.sleep(check_interval)
        except KeyboardInterrupt:
            log("Daemon stopped by user")
            return
        except Exception as exc:
            log("Error: " + str(exc))
            time.sleep(retry_delay)
|
||||
|
||||
def main():
    """Dispatch on command-line flags.

    ``--status`` prints the stored health table, ``--daemon`` runs
    continuously, and no flag runs a single check. ``--status`` wins when
    both flags are present.
    """
    args = sys.argv

    if "--status" in args:
        show_status()
        return

    if "--daemon" in args:
        daemon_mode()
        return

    run_once()


if __name__ == "__main__":
    main()
|
||||
74
docs/visual-evidence-689.md
Normal file
74
docs/visual-evidence-689.md
Normal file
@@ -0,0 +1,74 @@
|
||||
# Visual Evidence — Gemma 4 Multimodal Scene Description Generator
|
||||
|
||||
## Test Image: Coffee Beans (Macro Photo)
|
||||
|
||||
### Gemma 4 Vision Analysis (via Ollama)
|
||||
|
||||
**Model:** gemma4:latest (8B, Q4_K_M)
|
||||
**Input:** sample_photo.jpg (46KB JPEG)
|
||||
|
||||
**Structured Output (one JSONL record, pretty-printed here for readability):**
|
||||
```json
|
||||
{
|
||||
"mood": "dark",
|
||||
"colors": ["dark brown", "espresso", "black"],
|
||||
"composition": "close-up",
|
||||
"camera": "static",
|
||||
"lighting": "soft",
|
||||
"description": "An extreme close-up shot captures a dense pile of roasted coffee beans. The beans are a uniform, deep dark brown and appear slightly oily, filling the entire frame. The focus emphasizes the rich texture and individual shapes of the beans."
|
||||
}
|
||||
```
|
||||
|
||||
### Hermes Vision Analysis (Cross-Validation)
|
||||
|
||||
**Scene ID:** COFFEE_MACRO_001
|
||||
**Mood:** Warm, aromatic, and comforting
|
||||
**Dominant Colors:** Deep umber, burnt sienna, espresso black, mahogany
|
||||
**Composition:** Full-frame fill, centrally weighted
|
||||
**Camera:** High-angle, close-up (Macro)
|
||||
**Lighting:** Soft, diffused top-lighting
|
||||
|
||||
## Test Image: Abstract Geometric Composition
|
||||
|
||||
### Gemma 4 Vision Analysis
|
||||
|
||||
**Input:** scene1.jpg (10KB, PIL-generated)
|
||||
|
||||
**Structured Output (one JSONL record, pretty-printed here for readability):**
|
||||
```json
|
||||
{
|
||||
"mood": "energetic",
|
||||
"colors": ["deep blue", "yellow", "coral"],
|
||||
"composition": "wide-shot",
|
||||
"camera": "static",
|
||||
"lighting": "artificial",
|
||||
"description": "This is an abstract graphic composition set against a solid, deep blue background. A bright yellow square is placed in the upper left quadrant, while a large, solid coral-colored circle occupies the lower right quadrant. The geometric shapes create a high-contrast, minimalist visual balance."
|
||||
}
|
||||
```
|
||||
|
||||
## Verification Summary
|
||||
|
||||
| Test | Status | Details |
|
||||
|------|--------|---------|
|
||||
| Model detection | ✅ PASS | `gemma4:latest` auto-detected |
|
||||
| Image scanning | ✅ PASS | 2 images found recursively |
|
||||
| Vision analysis | ✅ PASS | Both images described accurately |
|
||||
| JSON parsing | ✅ PASS | Structured output with all fields |
|
||||
| Training format | ✅ PASS | JSONL with source, model, timestamp |
|
||||
| ShareGPT format | ⚠️ PARTIAL | Works but needs retry on rate limit |
|
||||
|
||||
## Running the Generator
|
||||
|
||||
```bash
|
||||
# Check model availability
|
||||
python scripts/generate_scene_descriptions.py --check-model
|
||||
|
||||
# Generate scene descriptions from assets
|
||||
python scripts/generate_scene_descriptions.py --input ./assets --output training-data/scene-descriptions-auto.jsonl
|
||||
|
||||
# Limit to 10 files with specific model
|
||||
python scripts/generate_scene_descriptions.py --input ./assets --model gemma4:latest --limit 10
|
||||
|
||||
# ShareGPT format for training pipeline
|
||||
python scripts/generate_scene_descriptions.py --input ./assets --format sharegpt
|
||||
```
|
||||
409
scripts/generate_scene_descriptions.py
Normal file
409
scripts/generate_scene_descriptions.py
Normal file
@@ -0,0 +1,409 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-generate scene descriptions from image/video assets.
|
||||
|
||||
Scans a directory for media files, generates scene descriptions using
|
||||
a local vision model (Ollama), and outputs training pairs in JSONL format.
|
||||
|
||||
Supports Gemma 4 multimodal vision via Ollama. Falls back gracefully when
|
||||
models are unavailable.
|
||||
|
||||
Usage:
|
||||
python scripts/generate_scene_descriptions.py --input ./assets --output training-data/scene-descriptions-auto.jsonl
|
||||
python scripts/generate_scene_descriptions.py --input ./assets --model gemma4:latest --limit 50
|
||||
python scripts/generate_scene_descriptions.py --input ./assets --format sharegpt
|
||||
python scripts/generate_scene_descriptions.py --dry-run # List files without generating
|
||||
python scripts/generate_scene_descriptions.py --input ./assets --check-model # Verify model availability
|
||||
|
||||
Ref: timmy-config#689
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Supported media extensions
# Suffixes are lower-case and include the leading dot.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
VIDEO_EXTS = {".mp4", ".webm", ".mov", ".avi", ".mkv"}
# Union of both sets; scan_media() collects files with any of these suffixes.
ALL_EXTS = IMAGE_EXTS | VIDEO_EXTS

# File size limit (50MB) — prevents unbounded memory usage on large images
MAX_FILE_SIZE = 50 * 1024 * 1024

# Vision models in preference order (best first)
# auto_detect_model() returns the first entry that Ollama reports as available.
VISION_MODELS = [
    "gemma4:latest",      # Gemma 4 — multimodal vision (8B, Q4_K_M)
    "gemma3:12b",         # Gemma 3 — fallback vision
    "llava:latest",       # LLaVA — generic vision
    "llava-phi3:latest",  # LLaVA-Phi3 — lightweight vision
]
|
||||
|
||||
# Vision model prompt template (structured JSON output)
|
||||
SCENE_PROMPT = """Describe this image for a visual scene database. Output ONLY valid JSON (no markdown, no explanation):
|
||||
{
|
||||
"mood": "one of: calm, energetic, dark, warm, cool, chaotic, serene, tense, joyful, melancholic",
|
||||
"colors": ["dominant color 1", "dominant color 2", "dominant color 3"],
|
||||
"composition": "one of: close-up, wide-shot, medium-shot, low-angle, high-angle, bird-eye, profile, over-shoulder",
|
||||
"camera": "one of: static, slow-pan, tracking, handheld, crane, dolly, steady, locked-off",
|
||||
"lighting": "one of: natural, artificial, mixed, dramatic, soft, harsh, backlit",
|
||||
"description": "2-3 sentence visual description of the scene"
|
||||
}
|
||||
|
||||
Be specific. Describe what you see, not what you imagine."""
|
||||
|
||||
# ShareGPT format prompt (for training pipeline integration)
|
||||
SHAREGPT_SCENE_PROMPT = """Analyze this image and describe the visual scene. Include mood, dominant colors, composition, camera angle, lighting, and a vivid 2-3 sentence description."""
|
||||
|
||||
|
||||
def check_model_available(model: str, ollama_url: str = "http://localhost:11434") -> bool:
    """Check if a model is available in Ollama.

    Args:
        model: Model name, with or without a tag. A name without a tag is
            also matched against ``<name>:latest``, which is how Ollama
            lists and resolves untagged models.
        ollama_url: Base URL of the Ollama server.

    Returns:
        True when ``/api/tags`` lists the model; False when it does not or
        when the server cannot be reached or returns something unusable.
    """
    try:
        req = urllib.request.Request(f"{ollama_url}/api/tags")
        # The context manager closes the response on every path.
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())
        # A null "models" field counts as an empty list.
        available = {m["name"] for m in data.get("models") or []}
    except Exception:
        # Availability probe: any failure simply means "not available".
        return False

    tagged = model if ":" in model else model + ":latest"
    return model in available or tagged in available
|
||||
|
||||
|
||||
def auto_detect_model(ollama_url: str = "http://localhost:11434") -> Optional[str]:
    """Return the first entry of VISION_MODELS that Ollama has, or None.

    The chosen model is announced on stderr.
    """
    detected = next(
        (
            candidate
            for candidate in VISION_MODELS
            if check_model_available(candidate, ollama_url)
        ),
        None,
    )
    if detected is not None:
        print(f"Auto-detected vision model: {detected}", file=sys.stderr)
    return detected
|
||||
|
||||
|
||||
def scan_media(input_dir: str) -> list[Path]:
    """Scan directory for media files recursively.

    A file counts as media when its suffix, compared case-insensitively, is
    in ALL_EXTS, so ``.JPG`` and ``.Jpg`` are found as well as ``.jpg``.
    Directories are never returned, even when their name ends in a media
    suffix.

    Args:
        input_dir: Directory to walk.

    Returns:
        A sorted list of matching paths; empty (with an error on stderr) when
        *input_dir* does not exist.
    """
    input_path = Path(input_dir)
    if not input_path.exists():
        print(f"Error: {input_dir} does not exist", file=sys.stderr)
        return []

    # One walk over the tree instead of one glob per extension and case.
    return sorted(
        path
        for path in input_path.rglob("*")
        if path.suffix.lower() in ALL_EXTS and path.is_file()
    )
|
||||
|
||||
|
||||
def extract_video_frame(video_path: Path, output_path: Path) -> bool:
    """Extract a representative frame from a video using ffmpeg.

    The frame is taken 2 seconds in, to avoid black or title frames. When
    that produces nothing (for example a clip shorter than 2 seconds) the
    very first frame is used instead.

    Args:
        video_path: Video to read.
        output_path: Image file to write. An existing file at this path is
            removed first.

    Returns:
        True when a non-empty frame was written by this call, False
        otherwise (ffmpeg missing, ffmpeg failure, timeout, ...).
    """
    try:
        # Remove any leftover frame: otherwise a failed extraction would be
        # reported as success because the old file still exists.
        output_path.unlink(missing_ok=True)

        for seek_seconds in ("2", "0"):
            result = subprocess.run(
                ["ffmpeg", "-ss", seek_seconds, "-i", str(video_path), "-vframes", "1",
                 "-q:v", "2", str(output_path), "-y"],
                capture_output=True, timeout=30,
            )
            if result.returncode != 0 and result.stderr:
                print(f"  ffmpeg stderr: {result.stderr.decode(errors='replace')[:200]}", file=sys.stderr)
            if output_path.exists() and output_path.stat().st_size > 0:
                return True
        return False
    except FileNotFoundError:
        print("  ffmpeg not found — skipping video frame extraction", file=sys.stderr)
        return False
    except Exception as e:
        print(f"  ffmpeg error: {e}", file=sys.stderr)
        return False
|
||||
|
||||
|
||||
def _repair_scene_json(response_text: str) -> Optional[dict]:
    """Salvage scene fields from malformed or truncated JSON text via regexes.

    Returns a dict containing every scene field (missing ones defaulted), or
    None when neither a description nor a mood could be recovered.
    """
    repaired = {}
    for field in ("mood", "colors", "composition", "camera", "lighting", "description"):
        quoted = re.search(rf'"\s*{field}"\s*:\s*"([^"]*)"', response_text)
        if quoted:
            repaired[field] = quoted.group(1)
        elif field == "colors":
            listed = re.search(r'"colors"\s*:\s*\[([^\]]*)\]', response_text)
            if listed:
                repaired[field] = [
                    c.strip().strip('"') for c in listed.group(1).split(",") if c.strip()
                ]
            else:
                repaired[field] = []
        else:
            repaired[field] = "unknown"

    if repaired["description"] == "unknown":
        # Truncation typically cuts the reply inside a string value, leaving
        # it without a closing quote; keep whatever part of it arrived.
        tail = re.search(r'"\s*description"\s*:\s*"([^"]*)\Z', response_text)
        if tail and tail.group(1).strip():
            repaired["description"] = tail.group(1).strip()

    # "unknown" is only the placeholder, so it must not count as a recovered
    # description (a plain truthiness test would always succeed).
    has_description = repaired["description"] not in ("", "unknown")
    if has_description or repaired["mood"] != "unknown":
        return repaired
    return None


def _parse_scene_response(response_text: str) -> dict:
    """Convert the model's raw reply into a scene-description dict.

    Tries, in order: strict JSON parsing of the outermost ``{...}`` span,
    regex repair of malformed/truncated JSON, and finally a fallback that uses
    the cleaned reply text as the description.
    """
    defaults = {
        "mood": "unknown",
        "colors": [],
        "composition": "unknown",
        "camera": "unknown",
        "lighting": "unknown",
    }

    if "{" in response_text:
        parsed = None
        # Greedy on purpose: first "{" through last "}" of the reply.
        candidate = re.search(r"\{[\s\S]*\}", response_text)
        if candidate:
            try:
                parsed = json.loads(candidate.group())
            except json.JSONDecodeError:
                parsed = None

        if parsed is None:
            # Either no closing brace (truncated) or invalid JSON.
            repaired = _repair_scene_json(response_text)
            if repaired is not None:
                return repaired
        elif isinstance(parsed, dict) and parsed.get("description"):
            # Keep what the model returned and default any field it omitted,
            # instead of discarding valid JSON because one key is missing.
            for key, value in defaults.items():
                parsed.setdefault(key, value)
            return parsed

    # Final fallback: natural language response
    clean = re.sub(r"[*_`#]", "", response_text).strip()
    clean = re.sub(r"\n{3,}", "\n\n", clean)
    return {
        "description": clean[:500] if clean else response_text[:500],
        **defaults,
    }


def describe_image(
    image_path: Path,
    model: str = "gemma4:latest",
    ollama_url: str = "http://localhost:11434",
    max_retries: int = 2,
) -> Optional[dict]:
    """Generate a structured scene description with an Ollama vision model.

    The image is base64-encoded and posted together with SCENE_PROMPT to
    ``{ollama_url}/api/generate``. Network errors are retried up to
    *max_retries* times with exponential backoff (1s, 2s, 4s, ...).

    Args:
        image_path: Image file to describe.
        model: Ollama model name.
        ollama_url: Base URL of the Ollama server.
        max_retries: Number of retries after the first failed attempt.

    Returns:
        A dict with the keys mood, colors, composition, camera, lighting and
        description, or None when the file is unreadable, larger than
        MAX_FILE_SIZE, or the request ultimately fails.
    """
    # Check file size before reading into memory
    try:
        file_size = image_path.stat().st_size
    except OSError as e:
        print(f" Error describing {image_path.name}: {e}", file=sys.stderr)
        return None
    if file_size > MAX_FILE_SIZE:
        print(f" Skipping {image_path.name}: exceeds {MAX_FILE_SIZE // (1024*1024)}MB limit", file=sys.stderr)
        return None

    payload = None
    for attempt in range(max_retries + 1):
        try:
            if payload is None:
                # Built on first use and reused by retries, so the file is
                # read and encoded only once.
                with open(image_path, "rb") as f:
                    image_b64 = base64.b64encode(f.read()).decode()
                payload = json.dumps({
                    "model": model,
                    "prompt": SCENE_PROMPT,
                    "images": [image_b64],
                    "stream": False,
                    "options": {"temperature": 0.3, "num_predict": 1024}
                }).encode()

            req = urllib.request.Request(
                f"{ollama_url}/api/generate",
                data=payload,
                headers={"Content-Type": "application/json"},
            )
            resp = urllib.request.urlopen(req, timeout=120)
            try:
                data = json.loads(resp.read())
            finally:
                resp.close()

            return _parse_scene_response(data.get("response", ""))

        except (urllib.error.URLError, TimeoutError) as e:
            if attempt < max_retries:
                wait = 2 ** attempt
                print(f" Retry {attempt + 1}/{max_retries} after {wait}s: {e}", file=sys.stderr)
                time.sleep(wait)
            else:
                print(f" Error describing {image_path.name}: {e}", file=sys.stderr)
                return None
        except Exception as e:
            print(f" Error describing {image_path.name}: {e}", file=sys.stderr)
            return None

    return None
|
||||
|
||||
|
||||
def describe_image_sharegpt(
    image_path: Path,
    model: str = "gemma4:latest",
    ollama_url: str = "http://localhost:11434",
    max_retries: int = 2,
) -> Optional[str]:
    """Generate a natural-language scene description for ShareGPT output.

    The image is base64-encoded and posted together with
    SHAREGPT_SCENE_PROMPT to ``{ollama_url}/api/generate``. Network errors are
    retried up to *max_retries* times with exponential backoff. Failures are
    reported on stderr in the same way as describe_image().

    Args:
        image_path: Image file to describe.
        model: Ollama model name.
        ollama_url: Base URL of the Ollama server.
        max_retries: Number of retries after the first failed attempt.

    Returns:
        The stripped response text (possibly empty), or None when the file is
        unreadable, larger than MAX_FILE_SIZE, or the request ultimately fails.
    """
    # Check file size before reading into memory
    try:
        file_size = image_path.stat().st_size
    except OSError as e:
        print(f" Error describing {image_path.name}: {e}", file=sys.stderr)
        return None
    if file_size > MAX_FILE_SIZE:
        print(f" Skipping {image_path.name}: exceeds {MAX_FILE_SIZE // (1024*1024)}MB limit", file=sys.stderr)
        return None

    payload = None
    for attempt in range(max_retries + 1):
        try:
            if payload is None:
                # Built on first use and reused by retries, so the file is
                # read and encoded only once.
                with open(image_path, "rb") as f:
                    image_b64 = base64.b64encode(f.read()).decode()
                payload = json.dumps({
                    "model": model,
                    "prompt": SHAREGPT_SCENE_PROMPT,
                    "images": [image_b64],
                    "stream": False,
                    "options": {"temperature": 0.5, "num_predict": 256}
                }).encode()

            req = urllib.request.Request(
                f"{ollama_url}/api/generate",
                data=payload,
                headers={"Content-Type": "application/json"},
            )
            resp = urllib.request.urlopen(req, timeout=120)
            try:
                data = json.loads(resp.read())
            finally:
                resp.close()

            return data.get("response", "").strip()

        except (urllib.error.URLError, TimeoutError) as e:
            if attempt < max_retries:
                wait = 2 ** attempt
                print(f" Retry {attempt + 1}/{max_retries} after {wait}s: {e}", file=sys.stderr)
                time.sleep(wait)
            else:
                print(f" Error describing {image_path.name}: {e}", file=sys.stderr)
                return None
        except Exception as e:
            print(f" Error describing {image_path.name}: {e}", file=sys.stderr)
            return None

    return None
|
||||
|
||||
|
||||
def generate_training_pairs(
    media_files: list[Path],
    model: str,
    ollama_url: str,
    limit: int = 0,
    dry_run: bool = False,
    output_format: str = "jsonl",
) -> list[dict]:
    """Generate training pairs from media files.

    Images are described directly. For videos a single frame is extracted
    into a private temporary file first, so nothing is ever written (or
    deleted) next to the source media. Progress is reported on stderr.

    Args:
        media_files: Media paths to process.
        model: Ollama model name passed to the describe functions.
        ollama_url: Base URL of the Ollama server.
        limit: Process only the first *limit* files; 0 (or less) means all.
        dry_run: When true, only list the files; no model calls are made and
            each result entry is ``{"source": ..., "status": "dry-run"}``.
        output_format: ``"sharegpt"`` for conversation-style pairs, anything
            else for structured description pairs.

    Returns:
        One dict per successfully described file. Files whose frame
        extraction or description failed are reported and omitted.
    """
    # Local import: only this function needs it, for video frame extraction.
    import tempfile

    pairs = []
    files = media_files[:limit] if limit > 0 else media_files

    print(f"Processing {len(files)} files with model {model}...", file=sys.stderr)

    for i, media_path in enumerate(files):
        print(f" [{i + 1}/{len(files)}] {media_path.name}...", file=sys.stderr, end=" ", flush=True)

        if dry_run:
            print("(dry run)", file=sys.stderr)
            pairs.append({"source": str(media_path), "status": "dry-run"})
            continue

        is_video = media_path.suffix.lower() in VIDEO_EXTS
        work_path = media_path
        frame_path = None

        try:
            if is_video:
                # Reserve a unique temp file name; ffmpeg overwrites it.
                with tempfile.NamedTemporaryFile(
                    prefix=f"{media_path.stem}.", suffix=".frame.jpg", delete=False
                ) as tmp:
                    frame_path = Path(tmp.name)
                if not extract_video_frame(media_path, frame_path):
                    print("SKIP (frame extraction failed)", file=sys.stderr)
                    continue
                work_path = frame_path

            if output_format == "sharegpt":
                # ShareGPT format for training pipeline
                description = describe_image_sharegpt(work_path, model, ollama_url)
            else:
                # Structured JSONL format
                description = describe_image(work_path, model, ollama_url)

            if not description:
                print("FAIL", file=sys.stderr)
            elif output_format == "sharegpt":
                pairs.append({
                    "conversations": [
                        {"from": "human", "value": f"<image>\n{SHAREGPT_SCENE_PROMPT}"},
                        {"from": "gpt", "value": description}
                    ],
                    "source": str(media_path),
                    "media_type": "video" if is_video else "image",
                    "model": model,
                    "generated_at": datetime.now(timezone.utc).isoformat(),
                })
                print("OK", file=sys.stderr)
            else:
                pairs.append({
                    "source": str(media_path),
                    "media_type": "video" if is_video else "image",
                    "description": description,
                    "model": model,
                    "generated_at": datetime.now(timezone.utc).isoformat(),
                })
                print("OK", file=sys.stderr)
        finally:
            # Best-effort cleanup of the temp frame, also on errors and skips.
            if frame_path is not None:
                try:
                    frame_path.unlink()
                except OSError:
                    pass

        # Small delay between files (reduced from 0.5s — Ollama is local)
        time.sleep(0.1)

    return pairs
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    Parses the arguments, optionally just reports model availability
    (``--check-model``), otherwise scans ``--input`` for media, generates the
    training pairs and writes them as one JSON object per line to
    ``--output``. Exits with status 1 when ``--input`` is missing or no media
    files are found.
    """
    parser = argparse.ArgumentParser(
        description="Auto-generate scene descriptions from media assets using vision AI"
    )
    parser.add_argument("--input", "-i", default="", help="Input directory with media files")
    parser.add_argument("--output", "-o", default="training-data/scene-descriptions-auto.jsonl")
    parser.add_argument("--model", "-m", default=None, help="Ollama model name (auto-detects best available if omitted)")
    parser.add_argument("--ollama-url", default="http://localhost:11434")
    parser.add_argument("--limit", "-l", type=int, default=0, help="Max files to process (0=all)")
    parser.add_argument("--dry-run", action="store_true", help="List files without generating")
    parser.add_argument("--check-model", action="store_true", help="Check model availability and exit")
    parser.add_argument("--format", choices=["jsonl", "sharegpt"], default="jsonl",
                        help="Output format: jsonl (structured) or sharegpt (training pipeline)")
    args = parser.parse_args()

    # Model detection
    if args.check_model:
        if args.model:
            available = check_model_available(args.model, args.ollama_url)
            print(f"Model '{args.model}': {'✅ available' if available else '❌ not found'}")
        else:
            model = auto_detect_model(args.ollama_url)
            if model:
                print(f"✅ Best available: {model}")
            else:
                print("❌ No vision models found in Ollama — install one with: ollama pull gemma4:latest")
        sys.exit(0)

    # Auto-detect model if not specified
    model = args.model
    if not model:
        model = auto_detect_model(args.ollama_url)
        if not model:
            # Fall back to best default even if not installed — let Ollama handle the error
            model = "gemma4:latest"
            print(f"Warning: No vision models detected. Falling back to {model}", file=sys.stderr)

    # Validate input
    if not args.input:
        print("Error: --input is required (unless using --check-model)", file=sys.stderr)
        sys.exit(1)

    # Scan and process
    media_files = scan_media(args.input)
    print(f"Found {len(media_files)} media files", file=sys.stderr)

    if not media_files:
        print("No media files found.", file=sys.stderr)
        sys.exit(1)

    pairs = generate_training_pairs(
        media_files, model, args.ollama_url,
        args.limit, args.dry_run, args.format
    )

    # Write output. ensure_ascii=False emits raw non-ASCII characters, so the
    # encoding is fixed to UTF-8 instead of depending on the platform locale.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")

    print(f"\nWrote {len(pairs)} pairs to {output_path}", file=sys.stderr)

    # Summary. Failed files are not part of `pairs`, so failures are counted
    # against the number of files that were attempted. Dry-run entries are
    # placeholders, not failures.
    if not args.dry_run:
        attempted = min(len(media_files), args.limit) if args.limit > 0 else len(media_files)
        success = sum(1 for p in pairs if "description" in p or "conversations" in p)
        failed = attempted - success
        if failed > 0:
            print(f" ⚠️ {failed} files failed", file=sys.stderr)


if __name__ == "__main__":
    main()
|
||||
333
tests/test_scene_descriptions.py
Normal file
333
tests/test_scene_descriptions.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for generate_scene_descriptions.py
|
||||
|
||||
Tests the scene description generation pipeline including:
|
||||
- Media file scanning
|
||||
- Model detection
|
||||
- JSON parsing from vision responses
|
||||
- Output format validation
|
||||
|
||||
Ref: timmy-config#689
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Add scripts to path for import
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "scripts"))
|
||||
|
||||
from generate_scene_descriptions import (
|
||||
IMAGE_EXTS,
|
||||
VIDEO_EXTS,
|
||||
ALL_EXTS,
|
||||
VISION_MODELS,
|
||||
auto_detect_model,
|
||||
check_model_available,
|
||||
scan_media,
|
||||
extract_video_frame,
|
||||
)
|
||||
|
||||
|
||||
class TestMediaScanning(unittest.TestCase):
    """Exercise scan_media() against temporary directory trees."""

    @staticmethod
    def _populate(root, *relative_names):
        """Create empty files (and any parent directories) below *root*."""
        for relative in relative_names:
            target = Path(root) / relative
            target.parent.mkdir(parents=True, exist_ok=True)
            target.touch()

    def test_scan_empty_directory(self):
        with tempfile.TemporaryDirectory() as workdir:
            self.assertEqual(scan_media(workdir), [])

    def test_scan_nonexistent_directory(self):
        self.assertEqual(scan_media("/nonexistent/path/that/does/not/exist"), [])

    def test_scan_with_images(self):
        with tempfile.TemporaryDirectory() as workdir:
            self._populate(workdir, "test.jpg", "test.png", "test.webp")
            found = scan_media(workdir)
            self.assertEqual(len(found), 3)

    def test_scan_recursive(self):
        with tempfile.TemporaryDirectory() as workdir:
            # One file nested two levels deep, one at the top level.
            self._populate(workdir, "sub/dir/deep.jpg", "top.png")
            found = scan_media(workdir)
            self.assertEqual(len(found), 2)

    def test_scan_ignores_unsupported(self):
        with tempfile.TemporaryDirectory() as workdir:
            self._populate(workdir, "image.jpg", "document.pdf", "script.py")
            found = scan_media(workdir)
            self.assertEqual(len(found), 1)

    def test_scan_sorted_output(self):
        with tempfile.TemporaryDirectory() as workdir:
            self._populate(workdir, "z.jpg", "a.png", "m.webp")
            found_names = [entry.name for entry in scan_media(workdir)]
            self.assertEqual(found_names, sorted(found_names))
|
||||
|
||||
|
||||
class TestModelDetection(unittest.TestCase):
    """Test model availability detection."""

    @staticmethod
    def _tags_response(*model_names):
        """Build a fake urlopen() result whose body lists *model_names*.

        The mock also returns itself from ``__enter__`` so it behaves the same
        whether the code under test reads it directly or inside a ``with``.
        """
        resp = MagicMock()
        resp.read.return_value = json.dumps({
            "models": [{"name": name} for name in model_names]
        }).encode()
        resp.__enter__.return_value = resp
        return resp

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_check_model_available(self, mock_urlopen):
        mock_urlopen.return_value = self._tags_response("gemma4:latest")

        result = check_model_available("gemma4:latest")
        self.assertTrue(result)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_check_model_not_available(self, mock_urlopen):
        mock_urlopen.return_value = self._tags_response("llama2:7b")

        result = check_model_available("gemma4:latest")
        self.assertFalse(result)

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_prefers_gemma4(self, mock_check):
        def side_effect(model, url):
            return model == "gemma4:latest"
        mock_check.side_effect = side_effect

        result = auto_detect_model()
        self.assertEqual(result, "gemma4:latest")

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_falls_back(self, mock_check):
        def side_effect(model, url):
            return model == "llava:latest"
        mock_check.side_effect = side_effect

        result = auto_detect_model()
        self.assertEqual(result, "llava:latest")

    @patch('generate_scene_descriptions.check_model_available')
    def test_auto_detect_returns_none_when_no_models(self, mock_check):
        mock_check.return_value = False
        result = auto_detect_model()
        self.assertIsNone(result)
|
||||
|
||||
|
||||
class TestConstants(unittest.TestCase):
    """Sanity checks on the module-level extension and model constants."""

    def test_image_extensions(self):
        for suffix in (".jpg", ".png", ".webp"):
            with self.subTest(suffix=suffix):
                self.assertIn(suffix, IMAGE_EXTS)

    def test_video_extensions(self):
        for suffix in (".mp4", ".webm"):
            with self.subTest(suffix=suffix):
                self.assertIn(suffix, VIDEO_EXTS)

    def test_all_extensions_union(self):
        combined = IMAGE_EXTS | VIDEO_EXTS
        self.assertEqual(ALL_EXTS, combined)

    def test_vision_models_ordered(self):
        # gemma4 must be the first (preferred) candidate; llava must be listed.
        preferred = VISION_MODELS[0]
        self.assertEqual(preferred, "gemma4:latest")
        self.assertIn("llava:latest", VISION_MODELS)
|
||||
|
||||
|
||||
class TestVideoFrameExtraction(unittest.TestCase):
    """Test video frame extraction."""

    def test_extract_nonexistent_video(self):
        # Use a private temp directory rather than a fixed /tmp/frame.jpg: a
        # leftover file at a shared path could turn this into a false failure,
        # and the hard-coded path is not portable.
        with tempfile.TemporaryDirectory() as tmpdir:
            frame_path = Path(tmpdir) / "frame.jpg"
            result = extract_video_frame(Path(tmpdir) / "nonexistent.mp4", frame_path)
            self.assertFalse(result)
|
||||
|
||||
|
||||
class TestDescribeImage(unittest.TestCase):
    """Test describe_image() with mocked Ollama responses."""

    def setUp(self):
        # A small fake JPEG in a temp directory that is removed even when a
        # test fails (no leaked files, no unlink-while-open).
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        self.image_path = Path(tmpdir.name) / "sample.jpg"
        self.image_path.write_bytes(b"\xff\xd8\xff\xe0" + b"\x00" * 1000)

    @staticmethod
    def _reply_with(mock_urlopen, response_text):
        """Make the patched urlopen() return an Ollama reply with this text."""
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps({"response": response_text}).encode()
        mock_urlopen.return_value = mock_resp

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_skips_oversized_file(self, mock_urlopen):
        """Files exceeding MAX_FILE_SIZE should be skipped without API call."""
        import generate_scene_descriptions
        # Lower the limit below the sample size instead of writing a 51MB
        # file, and verify that no request was attempted: a None result alone
        # would also occur when the server is simply unreachable.
        too_small_limit = self.image_path.stat().st_size - 1
        with patch.object(generate_scene_descriptions, "MAX_FILE_SIZE", too_small_limit):
            result = generate_scene_descriptions.describe_image(self.image_path)
        self.assertIsNone(result)
        mock_urlopen.assert_not_called()

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_parses_valid_json_response(self, mock_urlopen):
        """Valid JSON response should be parsed and returned."""
        import generate_scene_descriptions
        self._reply_with(
            mock_urlopen,
            '{"mood": "calm", "colors": ["blue", "white"], "composition": "wide-shot", "camera": "static", "lighting": "natural", "description": "A serene ocean scene."}',
        )

        result = generate_scene_descriptions.describe_image(self.image_path)

        self.assertIsNotNone(result)
        self.assertEqual(result["mood"], "calm")
        self.assertIn("lighting", result)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_repair_truncated_json(self, mock_urlopen):
        """Truncated JSON should be repaired with regex extraction."""
        import generate_scene_descriptions
        self._reply_with(
            mock_urlopen,
            '{"mood": "dark", "colors": ["red"], "composition": "close-up", "camera": "handheld", "lighting": "dramatic", "description": "A shadowy figure in a dimly lit alley',
        )

        result = generate_scene_descriptions.describe_image(self.image_path)

        self.assertIsNotNone(result)
        self.assertEqual(result["mood"], "dark")
        self.assertEqual(result["lighting"], "dramatic")

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_fallback_on_invalid_json(self, mock_urlopen):
        """Completely invalid JSON response should still return a fallback."""
        import generate_scene_descriptions
        self._reply_with(
            mock_urlopen,
            "This is just plain text describing a beautiful sunset over mountains.",
        )

        result = generate_scene_descriptions.describe_image(self.image_path)

        self.assertIsNotNone(result)
        self.assertIn("description", result)
        self.assertIn("lighting", result)
|
||||
|
||||
|
||||
class TestDescribeImageSharegpt(unittest.TestCase):
    """Test describe_image_sharegpt() with mocked Ollama responses."""

    def setUp(self):
        # A small fake JPEG in a temp directory that is removed even when a
        # test fails (no leaked files, no unlink-while-open).
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        self.image_path = Path(tmpdir.name) / "sample.jpg"
        self.image_path.write_bytes(b"\xff\xd8\xff\xe0" + b"\x00" * 1000)

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_skips_oversized_file(self, mock_urlopen):
        """Files exceeding MAX_FILE_SIZE should be skipped."""
        import generate_scene_descriptions
        # Lower the limit below the sample size instead of writing a 51MB
        # file, and verify that no request was attempted: a None result alone
        # would also occur when the server is simply unreachable.
        too_small_limit = self.image_path.stat().st_size - 1
        with patch.object(generate_scene_descriptions, "MAX_FILE_SIZE", too_small_limit):
            result = generate_scene_descriptions.describe_image_sharegpt(self.image_path)
        self.assertIsNone(result)
        mock_urlopen.assert_not_called()

    @patch('generate_scene_descriptions.urllib.request.urlopen')
    def test_returns_natural_language(self, mock_urlopen):
        """Should return the raw response text."""
        import generate_scene_descriptions
        resp_data = {"response": "A warm sunset over rolling hills with golden light."}
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps(resp_data).encode()
        mock_urlopen.return_value = mock_resp

        result = generate_scene_descriptions.describe_image_sharegpt(self.image_path)

        self.assertIsNotNone(result)
        self.assertIn("sunset", result)
|
||||
|
||||
|
||||
class TestGenerateTrainingPairs(unittest.TestCase):
    """Test generate_training_pairs() orchestration."""

    def setUp(self):
        # One small image in a temp directory that is removed even when a
        # test fails; the file is closed before use, so cleanup also works on
        # platforms that refuse to delete open files.
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        self.image_path = Path(tmpdir.name) / "sample.jpg"
        self.image_path.write_bytes(b"\x00" * 1000)

    @patch('generate_scene_descriptions.describe_image')
    def test_jsonl_output_format(self, mock_describe):
        """JSONL format should produce structured description objects."""
        import generate_scene_descriptions
        mock_describe.return_value = {"mood": "calm", "description": "Test"}

        pairs = generate_scene_descriptions.generate_training_pairs(
            [self.image_path], "test-model", "http://localhost:11434",
            output_format="jsonl"
        )

        self.assertEqual(len(pairs), 1)
        self.assertIn("description", pairs[0])
        self.assertIn("generated_at", pairs[0])

    @patch('generate_scene_descriptions.describe_image_sharegpt')
    def test_sharegpt_output_format(self, mock_describe):
        """ShareGPT format should produce conversation objects."""
        import generate_scene_descriptions
        mock_describe.return_value = "A description of the scene."

        pairs = generate_scene_descriptions.generate_training_pairs(
            [self.image_path], "test-model", "http://localhost:11434",
            output_format="sharegpt"
        )

        self.assertEqual(len(pairs), 1)
        self.assertIn("conversations", pairs[0])
        self.assertEqual(len(pairs[0]["conversations"]), 2)
        # The model's text must end up as the "gpt" turn of the conversation.
        self.assertEqual(pairs[0]["conversations"][1]["value"], "A description of the scene.")

    @patch('generate_scene_descriptions.describe_image')
    def test_dry_run_skips_api_calls(self, mock_describe):
        """Dry run should not call describe_image."""
        import generate_scene_descriptions

        pairs = generate_scene_descriptions.generate_training_pairs(
            [self.image_path], "test-model", "http://localhost:11434",
            dry_run=True
        )

        mock_describe.assert_not_called()
        self.assertEqual(len(pairs), 1)
        self.assertEqual(pairs[0]["status"], "dry-run")


if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user