""" GPU Inference Scheduler — Multi-Model Resource Management Queue-based model loading with priority lanes and VRAM budget tracking. Prevents GPU OOM crashes when multiple projects compete for VRAM. Priority lanes: 1. real-time (LPM) — highest priority, interactive 2. interactive (playground) — user-facing, medium priority 3. batch (harvester) — background, lowest priority """ import json import time import threading import logging from enum import IntEnum from pathlib import Path from typing import Dict, List, Optional, Any from dataclasses import dataclass, field, asdict logger = logging.getLogger("hermes.gpu_scheduler") class Priority(IntEnum): """Job priority levels. Lower value = higher priority.""" REALTIME = 1 # LPM, live video, interactive sessions INTERACTIVE = 2 # Playground, chat, user-facing BATCH = 3 # Harvester, overnight jobs, background @dataclass class ModelSpec: """Specification for a model and its VRAM requirements.""" name: str vram_mb: int # VRAM required in MB loader: str = "ollama" # How to load: ollama, vllm, llama_cpp, custom model_id: str = "" # Model identifier (e.g., "llama3:70b") cacheable: bool = True # Can be cached between jobs cpu_fallback: bool = True # Can fall back to CPU if GPU busy estimated_batch_ms: int = 1000 # Estimated time per batch @dataclass class InferenceJob: """A job requesting GPU inference.""" job_id: str project: str # "video_forge", "lpm", "playground", "harvester" model: ModelSpec priority: Priority batch_size: int = 1 created_at: float = field(default_factory=time.time) started_at: Optional[float] = None completed_at: Optional[float] = None status: str = "queued" # queued, loading, running, completed, failed error: Optional[str] = None use_cpu_fallback: bool = False @dataclass class GPUState: """Current GPU state.""" total_vram_mb: int = 0 used_vram_mb: int = 0 loaded_models: List[str] = field(default_factory=list) active_job: Optional[str] = None @property def available_vram_mb(self) -> int: return self.total_vram_mb - self.used_vram_mb def can_fit(self, model: ModelSpec) -> bool: return self.available_vram_mb >= model.vram_mb # Known models and their VRAM requirements MODEL_REGISTRY: Dict[str, ModelSpec] = { # Video Forge models "sd_xl": ModelSpec(name="Stable Diffusion XL", vram_mb=8192, loader="comfyui", model_id="sd_xl"), "heartmula": ModelSpec(name="HeartMuLa", vram_mb=4096, loader="custom", model_id="heartmula"), "wan2.1": ModelSpec(name="Wan2.1", vram_mb=12288, loader="custom", model_id="wan2.1"), # LPM models "lpm_video": ModelSpec(name="LPM Video Gen", vram_mb=16384, loader="custom", model_id="lpm_video"), "lpm_a2a": ModelSpec(name="LPM A2A", vram_mb=8192, loader="custom", model_id="lpm_a2a"), # Local inference (hermes) "llama3_70b": ModelSpec(name="Llama 3 70B", vram_mb=40960, loader="ollama", model_id="llama3:70b"), "llama3_8b": ModelSpec(name="Llama 3 8B", vram_mb=8192, loader="ollama", model_id="llama3:8b"), "mimo_v2_pro": ModelSpec(name="MiMo v2 Pro", vram_mb=16384, loader="ollama", model_id="xiaomi/mimo-v2-pro"), # Playground "sdxl_turbo": ModelSpec(name="SDXL Turbo", vram_mb=6144, loader="comfyui", model_id="sdxl_turbo"), } # Default VRAM budget (can be overridden) DEFAULT_VRAM_MB = 49152 # 48GB (e.g., L40S, A6000) class InferenceScheduler: """ GPU Inference Scheduler. Manages a queue of inference jobs with priority scheduling, VRAM budget tracking, and CPU fallback. 
""" def __init__( self, vram_budget_mb: int = DEFAULT_VRAM_MB, queue_db: str = "~/.hermes/gpu_scheduler.db", ): self.vram_budget_mb = vram_budget_mb self.queue_db = Path(queue_db).expanduser() self.queue_db.parent.mkdir(parents=True, exist_ok=True) # State self.gpu_state = GPUState(total_vram_mb=vram_budget_mb) self.job_queue: List[InferenceJob] = [] self.completed_jobs: List[InferenceJob] = [] self._lock = threading.Lock() self._running = False self._worker_thread: Optional[threading.Thread] = None # Load persisted state self._load_state() logger.info( "GPU Scheduler initialized: %dMB VRAM budget", vram_budget_mb, ) def _load_state(self): """Load state from SQLite.""" import sqlite3 conn = sqlite3.connect(str(self.queue_db)) conn.execute(""" CREATE TABLE IF NOT EXISTS jobs ( job_id TEXT PRIMARY KEY, project TEXT, model_name TEXT, priority INTEGER, batch_size INTEGER, created_at REAL, started_at REAL, completed_at REAL, status TEXT, error TEXT, use_cpu_fallback INTEGER ) """) conn.commit() # Load pending jobs rows = conn.execute( "SELECT * FROM jobs WHERE status IN ('queued', 'loading', 'running')" ).fetchall() for row in rows: model_name = row[2] model = MODEL_REGISTRY.get(model_name, ModelSpec(name=model_name, vram_mb=8192)) job = InferenceJob( job_id=row[0], project=row[1], model=model, priority=Priority(row[3]), batch_size=row[4], created_at=row[5], started_at=row[6], completed_at=row[7], status=row[8], error=row[9], use_cpu_fallback=bool(row[10]), ) self.job_queue.append(job) conn.close() logger.info("Loaded %d pending jobs", len(self.job_queue)) def _save_job(self, job: InferenceJob): """Persist job to SQLite.""" import sqlite3 conn = sqlite3.connect(str(self.queue_db)) conn.execute(""" INSERT OR REPLACE INTO jobs (job_id, project, model_name, priority, batch_size, created_at, started_at, completed_at, status, error, use_cpu_fallback) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( job.job_id, job.project, job.model.name, job.priority.value, job.batch_size, job.created_at, job.started_at, job.completed_at, job.status, job.error, int(job.use_cpu_fallback), )) conn.commit() conn.close() def submit_job( self, job_id: str, project: str, model_name: str, priority: Priority = Priority.BATCH, batch_size: int = 1, ) -> InferenceJob: """ Submit an inference job to the queue. Args: job_id: Unique job identifier project: Project name (video_forge, lpm, playground, harvester) model_name: Model name from MODEL_REGISTRY priority: Job priority batch_size: Number of items to process Returns: The created InferenceJob """ model = MODEL_REGISTRY.get(model_name) if not model: raise ValueError(f"Unknown model: {model_name}. 

        job = InferenceJob(
            job_id=job_id,
            project=project,
            model=model,
            priority=priority,
            batch_size=batch_size,
        )

        with self._lock:
            # Insert in priority order
            inserted = False
            for i, existing in enumerate(self.job_queue):
                if job.priority < existing.priority:
                    self.job_queue.insert(i, job)
                    inserted = True
                    break
            if not inserted:
                self.job_queue.append(job)

        self._save_job(job)
        logger.info(
            "Job submitted: %s (project=%s, model=%s, priority=%s)",
            job_id, project, model_name, priority.name,
        )
        return job

    def get_next_job(self) -> Optional[InferenceJob]:
        """Get the next job to process based on priority and VRAM availability."""
        with self._lock:
            for job in self.job_queue:
                if job.status != "queued":
                    continue
                # Check if model fits in VRAM
                if self.gpu_state.can_fit(job.model):
                    return job
                # Check CPU fallback
                if job.model.cpu_fallback:
                    job.use_cpu_fallback = True
                    return job
        return None

    def start_job(self, job: InferenceJob) -> bool:
        """
        Mark a job as started and load its model.

        Returns True if successful, False if insufficient VRAM.
        """
        with self._lock:
            if not job.use_cpu_fallback:
                if not self.gpu_state.can_fit(job.model):
                    logger.warning(
                        "Insufficient VRAM for %s: need %dMB, have %dMB",
                        job.model.name,
                        job.model.vram_mb,
                        self.gpu_state.available_vram_mb,
                    )
                    return False
                # Reserve VRAM
                self.gpu_state.used_vram_mb += job.model.vram_mb
                if job.model.name not in self.gpu_state.loaded_models:
                    self.gpu_state.loaded_models.append(job.model.name)

            job.status = "loading"
            job.started_at = time.time()
            self.gpu_state.active_job = job.job_id

        self._save_job(job)
        logger.info(
            "Job started: %s (model=%s, cpu_fallback=%s, vram_used=%dMB)",
            job.job_id, job.model.name, job.use_cpu_fallback,
            self.gpu_state.used_vram_mb,
        )
        return True

    def complete_job(self, job: InferenceJob, error: Optional[str] = None):
        """Mark a job as completed and release its VRAM."""
        with self._lock:
            job.completed_at = time.time()
            job.status = "completed" if not error else "failed"
            job.error = error

            if not job.use_cpu_fallback:
                # Release VRAM
                self.gpu_state.used_vram_mb = max(
                    0,
                    self.gpu_state.used_vram_mb - job.model.vram_mb,
                )

            if self.gpu_state.active_job == job.job_id:
                self.gpu_state.active_job = None

            # Move to completed
            self.job_queue.remove(job)
            self.completed_jobs.append(job)

        self._save_job(job)
        duration = (job.completed_at - job.started_at) * 1000 if job.started_at else 0
        logger.info(
            "Job completed: %s (status=%s, duration=%.0fms)",
            job.job_id, job.status, duration,
        )

    def get_status(self) -> Dict[str, Any]:
        """Get scheduler status."""
        with self._lock:
            return {
                "gpu": {
                    "total_vram_mb": self.gpu_state.total_vram_mb,
                    "used_vram_mb": self.gpu_state.used_vram_mb,
                    "available_vram_mb": self.gpu_state.available_vram_mb,
                    "utilization_pct": round(
                        self.gpu_state.used_vram_mb / self.gpu_state.total_vram_mb * 100, 1
                    ),
                    "loaded_models": self.gpu_state.loaded_models,
                    "active_job": self.gpu_state.active_job,
                },
                "queue": {
                    "pending": len([j for j in self.job_queue if j.status == "queued"]),
                    "loading": len([j for j in self.job_queue if j.status == "loading"]),
                    "running": len([j for j in self.job_queue if j.status == "running"]),
                    "by_priority": {
                        p.name: len([
                            j for j in self.job_queue
                            if j.priority == p and j.status == "queued"
                        ])
                        for p in Priority
                    },
                },
                "completed": {
                    "total": len(self.completed_jobs),
                    "success": len([j for j in self.completed_jobs if j.status == "completed"]),
                    "failed": len([j for j in self.completed_jobs if j.status == "failed"]),
                },
            }
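
    # ------------------------------------------------------------------
    # Illustrative sketch (not part of the original API): __init__ sets up
    # self._running and self._worker_thread, but the module does not define
    # the worker loop itself. A minimal loop, assuming the caller supplies a
    # hypothetical `run_fn(job)` callback that performs the actual inference,
    # could look like this.
    # ------------------------------------------------------------------
    def start_worker(self, run_fn, poll_interval: float = 0.5):
        """Drain the queue in a background thread using ``run_fn(job)`` (sketch)."""
        if self._running:
            return
        self._running = True

        def _loop():
            while self._running:
                job = self.get_next_job()
                if job is None or not self.start_job(job):
                    time.sleep(poll_interval)
                    continue
                job.status = "running"
                try:
                    run_fn(job)  # hypothetical, project-specific inference call
                    self.complete_job(job)
                except Exception as exc:
                    self.complete_job(job, error=str(exc))

        self._worker_thread = threading.Thread(target=_loop, daemon=True)
        self._worker_thread.start()

    def stop_worker(self):
        """Signal the worker loop to exit and wait briefly for it (sketch)."""
        self._running = False
        if self._worker_thread:
            self._worker_thread.join(timeout=5)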
model.""" MODEL_REGISTRY[name] = spec logger.info("Registered model: %s (%dMB VRAM)", name, spec.vram_mb) def clear_completed(self): """Clear completed jobs from memory (keep in DB).""" with self._lock: self.completed_jobs.clear() # ============================================================================ # CLI Interface # ============================================================================ def main(): """CLI entry point for testing.""" import argparse parser = argparse.ArgumentParser(description="GPU Inference Scheduler") parser.add_argument("action", choices=["status", "submit", "list", "clear"]) parser.add_argument("--job-id", help="Job ID for submit") parser.add_argument("--project", help="Project name") parser.add_argument("--model", help="Model name") parser.add_argument("--priority", choices=["realtime", "interactive", "batch"], default="batch") parser.add_argument("--vram", type=int, default=DEFAULT_VRAM_MB, help="VRAM budget in MB") args = parser.parse_args() scheduler = InferenceScheduler(vram_budget_mb=args.vram) if args.action == "status": status = scheduler.get_status() print(json.dumps(status, indent=2)) elif args.action == "submit": if not all([args.job_id, args.project, args.model]): print("Error: --job-id, --project, and --model required for submit") return priority = Priority[args.priority.upper()] job = scheduler.submit_job(args.job_id, args.project, args.model, priority) print(f"Submitted: {job.job_id}") elif args.action == "list": print(f"Pending jobs: {len(scheduler.job_queue)}") for job in scheduler.job_queue: print(f" {job.job_id}: {job.project}/{job.model.name} [{job.priority.name}] {job.status}") elif args.action == "clear": scheduler.clear_completed() print("Cleared completed jobs from memory") if __name__ == "__main__": main()