This commit was merged in pull request #1252.
This commit is contained in:
@@ -6,6 +6,48 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAppleSiliconHelpers:
|
||||
"""Tests for is_apple_silicon() and _build_experiment_env()."""
|
||||
|
||||
def test_is_apple_silicon_true_on_arm64_darwin(self):
|
||||
from timmy.autoresearch import is_apple_silicon
|
||||
|
||||
with patch("timmy.autoresearch.platform.system", return_value="Darwin"), \
|
||||
patch("timmy.autoresearch.platform.machine", return_value="arm64"):
|
||||
assert is_apple_silicon() is True
|
||||
|
||||
def test_is_apple_silicon_false_on_linux(self):
|
||||
from timmy.autoresearch import is_apple_silicon
|
||||
|
||||
with patch("timmy.autoresearch.platform.system", return_value="Linux"), \
|
||||
patch("timmy.autoresearch.platform.machine", return_value="x86_64"):
|
||||
assert is_apple_silicon() is False
|
||||
|
||||
def test_build_env_auto_resolves_mlx_on_apple_silicon(self):
|
||||
from timmy.autoresearch import _build_experiment_env
|
||||
|
||||
with patch("timmy.autoresearch.is_apple_silicon", return_value=True):
|
||||
env = _build_experiment_env(dataset="tinystories", backend="auto")
|
||||
|
||||
assert env["AUTORESEARCH_BACKEND"] == "mlx"
|
||||
assert env["AUTORESEARCH_DATASET"] == "tinystories"
|
||||
|
||||
def test_build_env_auto_resolves_cuda_on_non_apple(self):
|
||||
from timmy.autoresearch import _build_experiment_env
|
||||
|
||||
with patch("timmy.autoresearch.is_apple_silicon", return_value=False):
|
||||
env = _build_experiment_env(dataset="openwebtext", backend="auto")
|
||||
|
||||
assert env["AUTORESEARCH_BACKEND"] == "cuda"
|
||||
assert env["AUTORESEARCH_DATASET"] == "openwebtext"
|
||||
|
||||
def test_build_env_explicit_backend_not_overridden(self):
|
||||
from timmy.autoresearch import _build_experiment_env
|
||||
|
||||
env = _build_experiment_env(dataset="tinystories", backend="cpu")
|
||||
assert env["AUTORESEARCH_BACKEND"] == "cpu"
|
||||
|
||||
|
||||
class TestPrepareExperiment:
|
||||
"""Tests for prepare_experiment()."""
|
||||
|
||||
@@ -44,6 +86,24 @@ class TestPrepareExperiment:
|
||||
|
||||
assert "failed" in result.lower()
|
||||
|
||||
def test_prepare_passes_env_to_prepare_script(self, tmp_path):
|
||||
from timmy.autoresearch import prepare_experiment
|
||||
|
||||
repo_dir = tmp_path / "autoresearch"
|
||||
repo_dir.mkdir()
|
||||
(repo_dir / "prepare.py").write_text("pass")
|
||||
|
||||
with patch("timmy.autoresearch.subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
prepare_experiment(tmp_path, dataset="tinystories", backend="cpu")
|
||||
|
||||
# The prepare.py call is the second call (first is skipped since repo exists)
|
||||
prepare_call = mock_run.call_args
|
||||
assert prepare_call.kwargs.get("env") is not None or prepare_call[1].get("env") is not None
|
||||
call_kwargs = prepare_call.kwargs if prepare_call.kwargs else prepare_call[1]
|
||||
assert call_kwargs["env"]["AUTORESEARCH_DATASET"] == "tinystories"
|
||||
assert call_kwargs["env"]["AUTORESEARCH_BACKEND"] == "cpu"
|
||||
|
||||
|
||||
class TestRunExperiment:
|
||||
"""Tests for run_experiment()."""
|
||||
|
||||
Reference in New Issue
Block a user