diff --git a/cli.py b/cli.py index 70a202d3a..4f734fad9 100755 --- a/cli.py +++ b/cli.py @@ -571,12 +571,28 @@ def _setup_worktree(repo_root: str = None) -> Optional[Dict[str, str]]: include_file = Path(repo_root) / ".worktreeinclude" if include_file.exists(): try: + repo_root_resolved = Path(repo_root).resolve() + wt_path_resolved = wt_path.resolve() for line in include_file.read_text().splitlines(): entry = line.strip() if not entry or entry.startswith("#"): continue src = Path(repo_root) / entry dst = wt_path / entry + # Prevent path traversal: ensure src stays within repo_root + # and dst stays within the worktree directory + try: + src_resolved = src.resolve() + dst_resolved = dst.resolve(strict=False) + except (OSError, ValueError): + logger.debug("Skipping invalid .worktreeinclude entry: %s", entry) + continue + if not str(src_resolved).startswith(str(repo_root_resolved) + os.sep) and src_resolved != repo_root_resolved: + logger.warning("Skipping .worktreeinclude entry outside repo root: %s", entry) + continue + if not str(dst_resolved).startswith(str(wt_path_resolved) + os.sep) and dst_resolved != wt_path_resolved: + logger.warning("Skipping .worktreeinclude entry that escapes worktree: %s", entry) + continue if src.is_file(): dst.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(str(src), str(dst)) @@ -584,7 +600,7 @@ def _setup_worktree(repo_root: str = None) -> Optional[Dict[str, str]]: # Symlink directories (faster, saves disk) if not dst.exists(): dst.parent.mkdir(parents=True, exist_ok=True) - os.symlink(str(src.resolve()), str(dst)) + os.symlink(str(src_resolved), str(dst)) except Exception as e: logger.debug("Error copying .worktreeinclude entries: %s", e) diff --git a/tests/test_worktree.py b/tests/test_worktree.py index f545baa39..dd24381e7 100644 --- a/tests/test_worktree.py +++ b/tests/test_worktree.py @@ -633,3 +633,75 @@ class TestSystemPromptInjection: assert info["repo_root"] in wt_note assert "isolated git worktree" in wt_note assert "commit and push" in wt_note + + +class TestWorktreeIncludePathTraversal: + """Test that .worktreeinclude entries with path traversal are rejected.""" + + def test_rejects_parent_directory_traversal(self, git_repo): + """Entries like '../../etc/passwd' must not escape the repo root.""" + import shutil as _shutil + + # Create a sensitive file outside the repo to simulate the attack + outside_file = git_repo.parent / "sensitive.txt" + outside_file.write_text("SENSITIVE DATA") + + # Create a .worktreeinclude with a traversal entry + (git_repo / ".worktreeinclude").write_text("../sensitive.txt\n") + + info = _setup_worktree(str(git_repo)) + assert info is not None + + wt_path = Path(info["path"]) + + # Replay the fixed logic from cli.py + repo_root_resolved = Path(str(git_repo)).resolve() + wt_path_resolved = wt_path.resolve() + include_file = git_repo / ".worktreeinclude" + + copied_entries = [] + for line in include_file.read_text().splitlines(): + entry = line.strip() + if not entry or entry.startswith("#"): + continue + src = Path(str(git_repo)) / entry + dst = wt_path / entry + try: + src_resolved = src.resolve() + dst_resolved = dst.resolve(strict=False) + except (OSError, ValueError): + continue + if not str(src_resolved).startswith(str(repo_root_resolved) + os.sep) and src_resolved != repo_root_resolved: + continue + if not str(dst_resolved).startswith(str(wt_path_resolved) + os.sep) and dst_resolved != wt_path_resolved: + continue + copied_entries.append(entry) + + # The traversal entry must have been skipped + assert len(copied_entries) == 0 + # The sensitive file must NOT be in the worktree + assert not (wt_path / "../sensitive.txt").resolve().is_relative_to(wt_path_resolved) + + def test_allows_valid_entries(self, git_repo): + """Normal entries within the repo should still be processed.""" + (git_repo / ".env").write_text("KEY=val") + (git_repo / ".worktreeinclude").write_text(".env\n") + + info = _setup_worktree(str(git_repo)) + assert info is not None + + repo_root_resolved = Path(str(git_repo)).resolve() + include_file = git_repo / ".worktreeinclude" + + accepted = [] + for line in include_file.read_text().splitlines(): + entry = line.strip() + if not entry or entry.startswith("#"): + continue + src = Path(str(git_repo)) / entry + src_resolved = src.resolve() + if not str(src_resolved).startswith(str(repo_root_resolved) + os.sep) and src_resolved != repo_root_resolved: + continue + accepted.append(entry) + + assert ".env" in accepted