"""Tests for Morrowind command log and training export pipeline.""" from datetime import UTC, datetime, timedelta from pathlib import Path import pytest from src.infrastructure.morrowind.command_log import CommandLog, CommandLogger from src.infrastructure.morrowind.schemas import ( CommandInput, CommandType, PerceptionOutput, ) from src.infrastructure.morrowind.training_export import TrainingExporter # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- NOW = datetime(2026, 3, 21, 14, 30, 0, tzinfo=UTC) def _make_perception(**overrides) -> PerceptionOutput: defaults = { "timestamp": NOW, "agent_id": "timmy", "location": {"cell": "Balmora", "x": 1024.5, "y": -512.3, "z": 64.0}, "health": {"current": 85, "max": 100}, } defaults.update(overrides) return PerceptionOutput(**defaults) def _make_command(**overrides) -> CommandInput: defaults = { "timestamp": NOW, "agent_id": "timmy", "command": "move_to", "params": {"target_x": 1050.0}, "reasoning": "Moving closer to quest target.", } defaults.update(overrides) return CommandInput(**defaults) @pytest.fixture def logger(tmp_path: Path) -> CommandLogger: """CommandLogger backed by an in-memory SQLite DB.""" db_path = tmp_path / "test.db" return CommandLogger(db_url=f"sqlite:///{db_path}") @pytest.fixture def exporter(logger: CommandLogger) -> TrainingExporter: return TrainingExporter(logger) # --------------------------------------------------------------------------- # CommandLogger — log_command # --------------------------------------------------------------------------- class TestLogCommand: def test_basic_log(self, logger: CommandLogger): cmd = _make_command() row_id = logger.log_command(cmd) assert row_id >= 1 def test_log_with_perception(self, logger: CommandLogger): cmd = _make_command() perception = _make_perception() row_id = logger.log_command(cmd, perception=perception) assert row_id >= 1 results = logger.query(limit=1) assert len(results) == 1 assert results[0]["cell"] == "Balmora" assert results[0]["perception_snapshot"]["location"]["cell"] == "Balmora" def test_log_with_outcome(self, logger: CommandLogger): cmd = _make_command() row_id = logger.log_command(cmd, outcome="success: arrived at destination") results = logger.query(limit=1) assert results[0]["outcome"] == "success: arrived at destination" def test_log_preserves_episode_id(self, logger: CommandLogger): cmd = _make_command(episode_id="ep_test_001") logger.log_command(cmd) results = logger.query(episode_id="ep_test_001") assert len(results) == 1 assert results[0]["episode_id"] == "ep_test_001" # --------------------------------------------------------------------------- # CommandLogger — query # --------------------------------------------------------------------------- class TestQuery: def test_filter_by_command_type(self, logger: CommandLogger): logger.log_command(_make_command(command="move_to")) logger.log_command(_make_command(command="noop")) logger.log_command(_make_command(command="move_to")) results = logger.query(command_type="move_to") assert len(results) == 2 assert all(r["command"] == "move_to" for r in results) def test_filter_by_cell(self, logger: CommandLogger): p1 = _make_perception(location={"cell": "Balmora", "x": 0, "y": 0, "z": 0}) p2 = _make_perception(location={"cell": "Vivec", "x": 0, "y": 0, "z": 0}) logger.log_command(_make_command(), perception=p1) logger.log_command(_make_command(), perception=p2) results = logger.query(cell="Vivec") assert len(results) == 1 assert results[0]["cell"] == "Vivec" def test_filter_by_time_range(self, logger: CommandLogger): t1 = NOW - timedelta(hours=2) t2 = NOW - timedelta(hours=1) t3 = NOW logger.log_command(_make_command(timestamp=t1.isoformat())) logger.log_command(_make_command(timestamp=t2.isoformat())) logger.log_command(_make_command(timestamp=t3.isoformat())) results = logger.query(since=NOW - timedelta(hours=1, minutes=30), until=NOW) assert len(results) == 2 def test_limit_and_offset(self, logger: CommandLogger): for i in range(5): logger.log_command(_make_command()) results = logger.query(limit=2, offset=0) assert len(results) == 2 results = logger.query(limit=10, offset=3) assert len(results) == 2 def test_empty_query(self, logger: CommandLogger): results = logger.query() assert results == [] # --------------------------------------------------------------------------- # CommandLogger — export_training_data (JSONL) # --------------------------------------------------------------------------- class TestExportTrainingData: def test_basic_export(self, logger: CommandLogger, tmp_path: Path): perception = _make_perception() for _ in range(3): logger.log_command(_make_command(), perception=perception) output = tmp_path / "train.jsonl" count = logger.export_training_data(output) assert count == 3 assert output.exists() import json lines = output.read_text().strip().split("\n") assert len(lines) == 3 record = json.loads(lines[0]) assert "input" in record assert "output" in record assert record["output"]["command"] == "move_to" def test_export_filter_by_episode(self, logger: CommandLogger, tmp_path: Path): logger.log_command(_make_command(episode_id="ep_a"), perception=_make_perception()) logger.log_command(_make_command(episode_id="ep_b"), perception=_make_perception()) output = tmp_path / "ep_a.jsonl" count = logger.export_training_data(output, episode_id="ep_a") assert count == 1 # --------------------------------------------------------------------------- # CommandLogger — storage management # --------------------------------------------------------------------------- class TestStorageManagement: def test_count(self, logger: CommandLogger): assert logger.count() == 0 logger.log_command(_make_command()) logger.log_command(_make_command()) assert logger.count() == 2 def test_rotate_old_entries(self, logger: CommandLogger): old_time = NOW - timedelta(days=100) logger.log_command(_make_command(timestamp=old_time.isoformat())) logger.log_command(_make_command(timestamp=NOW.isoformat())) deleted = logger.rotate(max_age_days=90) assert deleted == 1 assert logger.count() == 1 def test_rotate_nothing_to_delete(self, logger: CommandLogger): logger.log_command(_make_command(timestamp=NOW.isoformat())) deleted = logger.rotate(max_age_days=1) assert deleted == 0 # --------------------------------------------------------------------------- # TrainingExporter — chat format # --------------------------------------------------------------------------- class TestTrainingExporterChat: def test_chat_format_export( self, logger: CommandLogger, exporter: TrainingExporter, tmp_path: Path ): perception = _make_perception() for _ in range(3): logger.log_command(_make_command(), perception=perception) output = tmp_path / "chat.jsonl" stats = exporter.export_chat_format(output) assert stats.total_records == 3 assert stats.format == "chat_completion" import json lines = output.read_text().strip().split("\n") record = json.loads(lines[0]) assert record["messages"][0]["role"] == "system" assert record["messages"][1]["role"] == "user" assert record["messages"][2]["role"] == "assistant" # --------------------------------------------------------------------------- # TrainingExporter — episode sequences # --------------------------------------------------------------------------- class TestTrainingExporterEpisodes: def test_episode_export( self, logger: CommandLogger, exporter: TrainingExporter, tmp_path: Path ): perception = _make_perception() for i in range(5): logger.log_command( _make_command(episode_id="ep_test"), perception=perception, ) output_dir = tmp_path / "episodes" stats = exporter.export_episode_sequences(output_dir, min_length=3) assert stats.episodes_exported == 1 assert stats.total_records == 5 assert (output_dir / "ep_test.jsonl").exists() def test_short_episodes_skipped( self, logger: CommandLogger, exporter: TrainingExporter, tmp_path: Path ): perception = _make_perception() logger.log_command(_make_command(episode_id="short"), perception=perception) output_dir = tmp_path / "episodes" stats = exporter.export_episode_sequences(output_dir, min_length=3) assert stats.episodes_exported == 0 assert stats.skipped_records == 1