diff --git a/tests/tools/test_gpu_scheduler.py b/tests/tools/test_gpu_scheduler.py
new file mode 100644
index 000000000..a4731bf42
--- /dev/null
+++ b/tests/tools/test_gpu_scheduler.py
@@ -0,0 +1,256 @@
+"""
+Tests for GPU Inference Scheduler.
+"""
+
+import pytest
+import tempfile
+import os
+from pathlib import Path
+
+from tools.gpu_scheduler import (
+    Priority,
+    ModelSpec,
+    InferenceJob,
+    InferenceScheduler,
+    MODEL_REGISTRY,
+)
+
+
+@pytest.fixture
+def scheduler():
+    """Create a scheduler with a temp database."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = Path(tmpdir) / "test_scheduler.db"
+        sched = InferenceScheduler(vram_budget_mb=32768, queue_db=str(db_path))
+        yield sched
+
+
+class TestPriority:
+    """Test priority ordering."""
+
+    def test_priority_ordering(self):
+        """Realtime < Interactive < Batch."""
+        assert Priority.REALTIME < Priority.INTERACTIVE
+        assert Priority.INTERACTIVE < Priority.BATCH
+
+    def test_priority_comparison(self):
+        """Lower value = higher priority."""
+        assert Priority.REALTIME.value == 1
+        assert Priority.INTERACTIVE.value == 2
+        assert Priority.BATCH.value == 3
+
+
+class TestModelSpec:
+    """Test model specifications."""
+
+    def test_model_registry_has_models(self):
+        """Registry should have known models."""
+        assert "llama3_70b" in MODEL_REGISTRY
+        assert "sd_xl" in MODEL_REGISTRY
+        assert "mimo_v2_pro" in MODEL_REGISTRY
+
+    def test_model_vram(self):
+        """Models should have VRAM requirements."""
+        llama = MODEL_REGISTRY["llama3_70b"]
+        assert llama.vram_mb > 0
+        assert llama.vram_mb == 40960  # 40GB
+
+
+class TestInferenceScheduler:
+    """Test the scheduler."""
+
+    def test_init(self, scheduler):
+        """Scheduler should initialize."""
+        assert scheduler.vram_budget_mb == 32768
+        assert scheduler.gpu_state.total_vram_mb == 32768
+        assert len(scheduler.job_queue) == 0
+
+    def test_submit_job(self, scheduler):
+        """Submit a job."""
+        job = scheduler.submit_job(
+            job_id="test-1",
+            project="playground",
+            model_name="llama3_8b",
+            priority=Priority.INTERACTIVE,
+        )
+
+        assert job.job_id == "test-1"
+        assert job.status == "queued"
+        assert len(scheduler.job_queue) == 1
+
+    def test_submit_unknown_model(self, scheduler):
+        """Submit with unknown model should raise."""
+        with pytest.raises(ValueError, match="Unknown model"):
+            scheduler.submit_job(
+                job_id="test-1",
+                project="playground",
+                model_name="nonexistent",
+            )
+
+    def test_priority_ordering(self, scheduler):
+        """Jobs should be ordered by priority."""
+        scheduler.submit_job("batch-1", "harvester", "llama3_8b", Priority.BATCH)
+        scheduler.submit_job("rt-1", "lpm", "llama3_8b", Priority.REALTIME)
+        scheduler.submit_job("int-1", "playground", "llama3_8b", Priority.INTERACTIVE)
+
+        # RT should be first
+        assert scheduler.job_queue[0].job_id == "rt-1"
+        assert scheduler.job_queue[1].job_id == "int-1"
+        assert scheduler.job_queue[2].job_id == "batch-1"
+
+    def test_get_next_job(self, scheduler):
+        """Get next job should return highest priority."""
+        scheduler.submit_job("batch-1", "harvester", "llama3_8b", Priority.BATCH)
+        scheduler.submit_job("rt-1", "lpm", "llama3_8b", Priority.REALTIME)
+
+        next_job = scheduler.get_next_job()
+        assert next_job.job_id == "rt-1"
+
+    def test_start_job(self, scheduler):
+        """Start a job."""
+        job = scheduler.submit_job("test-1", "playground", "llama3_8b", Priority.INTERACTIVE)
+        success = scheduler.start_job(job)
+
+        assert success
+        assert job.status == "loading"
+        assert job.started_at is not None
+        assert scheduler.gpu_state.used_vram_mb == 8192  # llama3_8b VRAM
+
+    def test_complete_job(self, scheduler):
+        """Complete a job."""
+        job = scheduler.submit_job("test-1", "playground", "llama3_8b", Priority.INTERACTIVE)
+        scheduler.start_job(job)
+        scheduler.complete_job(job)
+
+        assert job.status == "completed"
+        assert job.completed_at is not None
+        assert scheduler.gpu_state.used_vram_mb == 0
+        assert len(scheduler.job_queue) == 0
+        assert len(scheduler.completed_jobs) == 1
+
+    def test_complete_job_with_error(self, scheduler):
+        """Complete a job with error."""
+        job = scheduler.submit_job("test-1", "playground", "llama3_8b", Priority.INTERACTIVE)
+        scheduler.start_job(job)
+        scheduler.complete_job(job, error="CUDA out of memory")
+
+        assert job.status == "failed"
+        assert job.error == "CUDA out of memory"
+
+    def test_vram_tracking(self, scheduler):
+        """VRAM should be tracked correctly."""
+        # Submit two small jobs
+        job1 = scheduler.submit_job("test-1", "playground", "llama3_8b", Priority.INTERACTIVE)
+        job2 = scheduler.submit_job("test-2", "playground", "llama3_8b", Priority.INTERACTIVE)
+
+        # Start first
+        scheduler.start_job(job1)
+        assert scheduler.gpu_state.used_vram_mb == 8192
+
+        # Start second (should work, still have room)
+        scheduler.start_job(job2)
+        assert scheduler.gpu_state.used_vram_mb == 16384
+
+        # Complete first
+        scheduler.complete_job(job1)
+        assert scheduler.gpu_state.used_vram_mb == 8192
+
+    def test_cpu_fallback(self, scheduler):
+        """CPU fallback when VRAM full."""
+        # Fill VRAM with two 16GB models (32GB total = our budget)
+        job1 = scheduler.submit_job("big-1", "lpm", "mimo_v2_pro", Priority.REALTIME)
+        scheduler.start_job(job1)
+        assert scheduler.gpu_state.used_vram_mb == 16384
+
+        # Start another 16GB model (should work, exactly fills VRAM)
+        job2 = scheduler.submit_job("big-2", "playground", "mimo_v2_pro", Priority.INTERACTIVE)
+        scheduler.start_job(job2)
+        assert scheduler.gpu_state.used_vram_mb == 32768  # Full
+
+        # Now try a third model - should get CPU fallback
+        job3 = scheduler.submit_job("big-3", "harvester", "mimo_v2_pro", Priority.BATCH)
+        next_job = scheduler.get_next_job()
+
+        # Should get job3 with CPU fallback since VRAM is full
+        assert next_job.job_id == "big-3"
+        assert next_job.use_cpu_fallback
+
+    def test_get_status(self, scheduler):
+        """Get scheduler status."""
+        scheduler.submit_job("test-1", "playground", "llama3_8b", Priority.INTERACTIVE)
+        scheduler.submit_job("test-2", "harvester", "llama3_8b", Priority.BATCH)
+
+        status = scheduler.get_status()
+
+        assert status["gpu"]["total_vram_mb"] == 32768
+        assert status["queue"]["pending"] == 2
+        assert status["queue"]["by_priority"]["INTERACTIVE"] == 1
+        assert status["queue"]["by_priority"]["BATCH"] == 1
+
+    def test_register_model(self, scheduler):
+        """Register a custom model."""
+        custom = ModelSpec(name="Custom Model", vram_mb=4096)
+        scheduler.register_model("custom_model", custom)
+
+        assert "custom_model" in MODEL_REGISTRY
+
+        job = scheduler.submit_job("test-1", "playground", "custom_model")
+        assert job.model.vram_mb == 4096
+
+
+class TestCrossProjectScenarios:
+    """Test cross-project scenarios from the issue."""
+
+    def test_video_forge_batch_plus_lpm_live(self, scheduler):
+        """
+        Video Forge batch + LPM live.
+        LPM should get priority, batch should queue.
+        """
+        # Video Forge batch job
+        vf_job = scheduler.submit_job(
+            "vf-batch-1", "video_forge", "sd_xl", Priority.BATCH
+        )
+
+        # LPM live job (higher priority)
+        lpm_job = scheduler.submit_job(
+            "lpm-live-1", "lpm", "lpm_video", Priority.REALTIME
+        )
+
+        # Next job should be LPM
+        next_job = scheduler.get_next_job()
+        assert next_job.job_id == "lpm-live-1"
+        assert next_job.priority == Priority.REALTIME
+
+    def test_three_video_forge_jobs(self, scheduler):
+        """Three Video Forge jobs should queue sequentially."""
+        jobs = []
+        for i in range(3):
+            job = scheduler.submit_job(
+                f"vf-{i}", "video_forge", "sd_xl", Priority.BATCH
+            )
+            jobs.append(job)
+
+        # Start first
+        scheduler.start_job(jobs[0])
+        assert scheduler.gpu_state.used_vram_mb == 8192
+
+        # Second should queue (VRAM occupied)
+        next_job = scheduler.get_next_job()
+        assert next_job.job_id == "vf-1"
+
+    def test_night_harvester_plus_playground(self, scheduler):
+        """Night harvester runs on idle cycles."""
+        harvester = scheduler.submit_job(
+            "harvest-1", "harvester", "llama3_8b", Priority.BATCH
+        )
+        playground = scheduler.submit_job(
+            "play-1", "playground", "sdxl_turbo", Priority.INTERACTIVE
+        )
+
+        # Playground should get priority
+        next_job = scheduler.get_next_job()
+        assert next_job.job_id == "play-1"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])