#!/usr/bin/env python3
"""
vulnerability_scanner.py — Check Python dependencies against CVE databases (Issue #108)

Scans requirements.txt (or any pip-compatible dependency file) and queries
the Open Source Vulnerability (OSV) database for known security issues.

OSV API: https://api.osv.dev/v1/query (free, no auth, PyPI ecosystem supported)

Output:
    - Human-readable summary on stdout
    - JSON report with full vulnerability details
    - Exit code: 0 if no vulnerabilities found, 1 if critical/high found, 2 otherwise

Usage:
    python3 scripts/vulnerability_scanner.py
    python3 scripts/vulnerability_scanner.py --deps requirements.txt --output json
    python3 scripts/vulnerability_scanner.py --min-severity high
"""

import argparse
import json
import os
import re
import sys
import urllib.request
import urllib.error
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# --- Configuration ---

OSV_API_URL = "https://api.osv.dev/v1/query"
DEFAULT_REQUIREMENTS_PATH = "requirements.txt"
# Ordered most-severe first; the list position doubles as the severity rank
# used by filter_by_severity (lower index == more severe).
SEVERITY_LEVELS = ["critical", "high", "medium", "low", "unknown"]

# Map CVSS baseSeverity labels (as returned by OSV) to our severity buckets.
# NOTE(review): "none" is not in SEVERITY_LEVELS; filter_by_severity must
# tolerate it.
CVSS_SEVERITY_MAP = {
    "CRITICAL": "critical",
    "HIGH": "high",
    "MEDIUM": "medium",
    "LOW": "low",
    "NONE": "none",
}

# --- Data Structures ---


@dataclass
class Vulnerability:
    """A single vulnerability finding for one (package, version-spec) pair."""
    package: str                 # normalized (lower-cased) PyPI package name
    version: str                 # version spec from requirements, e.g. "==4.2.0"
    vuln_id: str                 # OSV identifier (GHSA-..., PYSEC-..., CVE-...)
    severity: str                # one of SEVERITY_LEVELS (or "none", see map above)
    cvss_score: Optional[float]  # CVSS base score when OSV provides one
    summary: str
    details_url: str
    fixed_versions: List[str]


@dataclass
class ScanResult:
    """Results from a vulnerability scan."""
    scanned_packages: int
    vulnerabilities: List[Vulnerability]
    errors: List[Tuple[str, str]]  # (package, error_message)


# --- Requirement Parsing ---


# "name[extras] <op> version" with optional whitespace around the extras and
# the comparison operator, e.g. "django[argon2]==4.2.0" or "requests >=2.28.0".
# BUG FIX: the original pattern's lazy ".*?" before an optional version group
# let the regex succeed with an EMPTY version whenever extras or whitespace
# preceded the operator, silently dropping the spec.
_REQUIREMENT_RE = re.compile(
    r'^([A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?)'   # package name (PEP 508 subset)
    r'(?:\s*\[[^\]]*\])?'                              # optional extras, discarded
    r'\s*((?:===|==|>=|<=|~=|!=)\s*[^\s;#]+)?'         # optional version spec
)


def parse_requirements_file(path: str) -> Dict[str, str]:
    """
    Parse a requirements.txt file into {package_name: version_spec}.

    Handles:
      - pkg==1.2.3 / pkg>=1.0.0 / pkg~=8.0
      - pkg[extra]==1.2.3 (extras are stripped from the key)
      - whitespace around the operator ("pkg >= 1.0" is stored as ">=1.0")
      - -e/--editable entries and any other pip option flags (skipped)
      - -r inclusions (recursive, limited to depth 3, cycle-safe)
      - comments and blank lines

    Exits the process with status 1 if *path* does not exist.
    """
    packages: Dict[str, str] = {}
    processed_includes = set()  # absolute paths already expanded (cycle guard)

    def parse_line(line: str, filename: str, depth: int = 0) -> None:
        line = line.strip()
        if not line or line.startswith('#'):
            return

        # Handle -r or --requirement includes
        if line.startswith('-r ') or line.startswith('--requirement '):
            if depth >= 3:
                # BUG FIX: the original printed the literal "(unknown)"
                # instead of the offending file's name.
                print(f"WARNING: Max include depth exceeded in {filename}", file=sys.stderr)
                return
            include_path = line.split(None, 1)[1].strip()
            # Resolve relative to the including file's directory
            base_dir = os.path.dirname(os.path.abspath(filename))
            full_path = os.path.join(base_dir, include_path)
            if full_path not in processed_includes:
                processed_includes.add(full_path)
                try:
                    with open(full_path, 'r', encoding='utf-8') as f:
                        for incl_line in f:
                            parse_line(incl_line, full_path, depth + 1)
                except FileNotFoundError:
                    print(f"WARNING: Could not read included file: {full_path}", file=sys.stderr)
            return

        # Skip editable installs (-e/--editable) and any other pip flags
        if line.startswith('-'):
            return

        match = _REQUIREMENT_RE.match(line)
        if not match or not match.group(1):
            return

        pkg_name = match.group(1).lower()
        # Normalize "pkg >= 1.0" to ">=1.0" so callers always get a compact spec.
        version = re.sub(r'\s+', '', match.group(2)) if match.group(2) else ""
        packages[pkg_name] = version

    # Read and parse the top-level file
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                parse_line(line, path, 0)
    except FileNotFoundError:
        print(f"ERROR: Requirements file not found: {path}", file=sys.stderr)
        sys.exit(1)

    return packages


# --- OSV API Queries ---


def query_osv(package: str, version: str) -> List[dict]:
    """
    Query the OSV API for vulnerabilities affecting a specific package version.

    OSV matches an exact version, not a range, so:
      - exact pins ("==4.2.0" / "===4.2.0") are queried with the concrete version;
      - bare versions ("4.2.0") are queried as-is;
      - range specs (">=...", "~=...") are queried without a version, returning
        every known vuln for the package for the caller to interpret.

    Returns the raw list of vulnerability dicts, or [] on any error (errors are
    reported on stderr, never raised).
    """
    # BUG FIX: the original only checked for a leading digit, so pinned specs
    # like "==4.2.0" were queried WITHOUT a version (all historical vulns).
    pin = re.match(r'^(?:===|==)\s*([0-9][^\s,]*)$', version)
    if pin:
        query_version = pin.group(1)
    elif re.match(r'^[0-9]', version):
        query_version = version
    else:
        query_version = None

    payload = {
        "package": {
            "name": package,
            "ecosystem": "PyPI"
        }
    }
    if query_version:
        payload["version"] = query_version

    data = json.dumps(payload).encode('utf-8')
    req = urllib.request.Request(
        OSV_API_URL,
        data=data,
        headers={'Content-Type': 'application/json'},
        method='POST'
    )

    try:
        with urllib.request.urlopen(req, timeout=15) as response:
            result = json.loads(response.read().decode('utf-8'))
        # The documented response key is "vulns"; tolerate "vulnerabilities" too.
        return result.get('vulns', []) + result.get('vulnerabilities', [])
    except urllib.error.HTTPError as e:
        if e.code == 404:
            return []  # Package unknown to OSV -> no findings
        print(f"WARNING: OSV query failed for {package}: HTTP {e.code}", file=sys.stderr)
    except (urllib.error.URLError, json.JSONDecodeError, TimeoutError) as e:
        print(f"WARNING: OSV query failed for {package}: {e}", file=sys.stderr)

    return []
def parse_osv_vuln(raw_vulns: List[dict], package: str, version_spec: str) -> List[Vulnerability]:
    """
    Convert raw OSV API response dicts into Vulnerability objects.

    Severity resolution:
      - CVSS_V3 with a dict score: use baseSeverity via CVSS_SEVERITY_MAP
        (authoritative — stops the search).
      - CVSS_V2 fallback: derive the bucket from baseScore using the CVSS v2
        bands (>=7.0 high, >=4.0 medium, else low). BUG FIX: the original set
        severity to "unknown" unconditionally here and assigned sev_str from
        the *type* field, leaving the V2 data unused.
      - otherwise: "unknown".
    """
    vulns = []
    for v in raw_vulns:
        vuln_id = v.get('id', 'UNKNOWN')
        summary = v.get('summary', 'No summary provided.')

        severity = "unknown"
        cvss_score = None
        for sev_info in v.get('severity', []):
            sev_type = sev_info.get('type')
            score = sev_info.get('score', '')
            if sev_type == 'CVSS_V3':
                if isinstance(score, dict):
                    cvss_score = score.get('baseScore')
                    sev_str = score.get('baseSeverity', '').upper()
                    severity = CVSS_SEVERITY_MAP.get(sev_str, 'unknown')
                break  # V3 is authoritative; stop looking
            elif sev_type == 'CVSS_V2' and isinstance(score, dict):
                cvss_score = score.get('baseScore')
                if isinstance(cvss_score, (int, float)):
                    if cvss_score >= 7.0:
                        severity = "high"
                    elif cvss_score >= 4.0:
                        severity = "medium"
                    else:
                        severity = "low"
                # Keep scanning: a later CVSS_V3 entry overrides this fallback.

        # Collect "fixed" events from all affected ranges; the version that
        # introduced the vuln is not needed for reporting.
        fixed_versions = []
        for aff in v.get('affected', []):
            for r in aff.get('ranges', []):
                for event in r.get('events', []):
                    if event.get('fixed'):
                        fixed_versions.append(event['fixed'])

        vulns.append(Vulnerability(
            package=package,
            version=version_spec,
            vuln_id=vuln_id,
            severity=severity,
            cvss_score=cvss_score,
            summary=summary,
            details_url=f"https://osv.dev/vulnerability/{vuln_id}",
            # Sorted for deterministic output (the original used list(set(...))).
            fixed_versions=sorted(set(fixed_versions)),
        ))

    return vulns


# --- Filtering & Reporting ---


def filter_by_severity(vulns: List[Vulnerability], min_severity: str) -> List[Vulnerability]:
    """
    Return only vulnerabilities at or above *min_severity*.

    An invalid *min_severity* disables filtering. Severities not present in
    SEVERITY_LEVELS (e.g. the "none" bucket produced by CVSS_SEVERITY_MAP)
    are ranked like "unknown". BUG FIX: the original called
    SEVERITY_LEVELS.index() directly and raised ValueError for "none".
    NOTE(review): "unknown" ranks below "low", so unknown-severity findings
    are excluded unless min_severity == "unknown" — preserved from original.
    """
    if min_severity.lower() not in SEVERITY_LEVELS:
        return vulns  # No filtering if the threshold itself is invalid

    min_idx = SEVERITY_LEVELS.index(min_severity.lower())
    unknown_idx = SEVERITY_LEVELS.index("unknown")
    filtered = []
    for v in vulns:
        try:
            sev_idx = SEVERITY_LEVELS.index(v.severity.lower())
        except ValueError:
            sev_idx = unknown_idx
        if sev_idx <= min_idx:  # lower index = more severe
            filtered.append(v)
    return filtered


def generate_text_report(result: ScanResult, packages: Dict[str, str]) -> str:
    """Generate a human-readable text report, grouped by severity."""
    lines = []
    lines.append("=" * 60)
    lines.append("Vulnerability Scan Report")
    lines.append("=" * 60)
    lines.append(f"Packages scanned: {result.scanned_packages}")
    lines.append(f"Vulnerabilities found: {len(result.vulnerabilities)}")

    if result.errors:
        lines.append(f"Errors: {len(result.errors)}")

    # Group findings by severity for ordered display
    by_severity: Dict[str, List[Vulnerability]] = {}
    for v in result.vulnerabilities:
        by_severity.setdefault(v.severity.upper(), []).append(v)

    for sev in ["CRITICAL", "HIGH", "MEDIUM", "LOW", "UNKNOWN"]:
        vuln_list = by_severity.get(sev, [])
        if vuln_list:
            lines.append(f"\n{sev}: {len(vuln_list)}")
            for v in vuln_list:
                lines.append(f"  [{v.package} {packages.get(v.package, '')}] {v.vuln_id}")
                lines.append(f"    {v.summary[:80]}")
                # NOTE(review): a 0.0 score is falsy and therefore not printed —
                # preserved from the original; confirm if 0.0 should appear.
                if v.cvss_score:
                    lines.append(f"    CVSS: {v.cvss_score}")
                if v.fixed_versions:
                    lines.append(f"    Fixed in: {', '.join(v.fixed_versions[:3])}")
                lines.append(f"    {v.details_url}")

    if result.errors:
        lines.append("\nERRORS:")
        for pkg, err in result.errors[:10]:  # cap noise at 10 entries
            lines.append(f"  {pkg}: {err}")

    lines.append("\n" + "=" * 60)
    return "\n".join(lines)


def generate_json_report(result: ScanResult, packages: Dict[str, str]) -> str:
    """Generate a machine-readable JSON report (pretty-printed)."""
    report = {
        "scanned_packages": result.scanned_packages,
        "vulnerabilities": [
            {
                "package": v.package,
                "version_spec": packages.get(v.package, v.version),
                "vulnerability_id": v.vuln_id,
                "severity": v.severity,
                "cvss_score": v.cvss_score,
                "summary": v.summary,
                "details_url": v.details_url,
                "fixed_versions": v.fixed_versions,
            }
            for v in result.vulnerabilities
        ],
        "errors": [{"package": p, "error": e} for p, e in result.errors],
    }
    return json.dumps(report, indent=2)


# --- Main Orchestration ---
Orchestration --- + + +def run_scan( + deps_path: str, + min_severity: str = "low", + query_osv_api: bool = True +) -> ScanResult: + """ + Execute the full vulnerability scan pipeline. + + Args: + deps_path: Path to requirements-style file + min_severity: Minimum severity to include in results + query_osv_api: If False, skip API calls (for testing/dry-run) + + Returns: + ScanResult with all findings + """ + # 1. Parse dependencies + packages = parse_requirements_file(deps_path) + if not packages: + return ScanResult(scanned_packages=0, vulnerabilities=[], errors=[]) + + # 2. Query OSV for each package + vulnerabilities: List[Vulnerability] = [] + errors: List[Tuple[str, str]] = [] + + for pkg, version_spec in packages.items(): + if not query_osv_api: + continue + + raw_vulns = query_osv(pkg, version_spec or "") + if raw_vulns: + parsed = parse_osv_vuln(raw_vulns, pkg, version_spec or "") + vulnerabilities.extend(parsed) + + # 3. Filter by severity + filtered = filter_by_severity(vulnerabilities, min_severity) + + # 4. 
Build result + return ScanResult( + scanned_packages=len(packages), + vulnerabilities=filtered, + errors=errors + ) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Scan Python dependencies for known vulnerabilities using OSV database" + ) + parser.add_argument( + '--deps', '-d', + default=DEFAULT_REQUIREMENTS_PATH, + help='Path to requirements.txt (default: requirements.txt)' + ) + parser.add_argument( + '--output', '-o', + choices=['text', 'json', 'markdown'], + default='text', + help='Output format (default: text)' + ) + parser.add_argument( + '--min-severity', + default='low', + choices=SEVERITY_LEVELS, + help='Minimum severity to report (default: low — report all)' + ) + parser.add_argument( + '--json', + action='store_true', + help='Output JSON (shorthand for --output json)' + ) + parser.add_argument( + '--quiet', '-q', + action='store_true', + help='Only print summary, skip detailed vulnerability list' + ) + + args = parser.parse_args() + + # Update output if --json flag is used + if args.json: + args.output = 'json' + + # Run the scan + result = run_scan(args.deps, args.min_severity, query_osv_api=True) + + # Output + if args.output == 'json': + print(generate_json_report(result, parse_requirements_file(args.deps))) + elif args.output == 'markdown': + # Simple markdown table + print("# Vulnerability Scan Report\n") + print(f"**Packages scanned:** {result.scanned_packages}") + print(f"**Vulnerabilities:** {len(result.vulnerabilities)}\n") + if result.vulnerabilities: + print("| Severity | Package | Version | Vuln ID | Summary |") + print("|----------|---------|---------|---------|---------|") + for v in result.vulnerabilities: + print(f"| {v.severity.upper()} | {v.package} | {v.version} | [{v.vuln_id}]({v.details_url}) | {v.summary[:50]} |") + print("\n") + else: + # text (default) + if not args.quiet: + print(generate_text_report(result, parse_requirements_file(args.deps))) + else: + crit = sum(1 for v in result.vulnerabilities if 
v.severity == 'critical') + high = sum(1 for v in result.vulnerabilities if v.severity == 'high') + med = sum(1 for v in result.vulnerabilities if v.severity == 'medium') + print(f"CRITICAL={crit} HIGH={high} MEDIUM={med} TOTAL={len(result.vulnerabilities)}") + + # Exit code logic: 0 if no vulns at min_severity+, 1 if critical/high found, 2 for other vulns + has_critical_high = any(v.severity in ('critical', 'high') for v in result.vulnerabilities) + has_other = any(v.severity not in ('critical', 'high') for v in result.vulnerabilities) + + if has_critical_high: + return 1 + elif has_other: + return 2 + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/tests/test_vulnerability_scanner.py b/tests/test_vulnerability_scanner.py new file mode 100644 index 0000000..ebba4b7 --- /dev/null +++ b/tests/test_vulnerability_scanner.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +"""Tests for scripts/vulnerability_scanner.py — 10 tests.""" + +import json +import os +import sys +import tempfile +import unittest +from unittest.mock import patch, MagicMock + +sys.path.insert(0, os.path.join(os.path.dirname(__file__) or ".", "..")) +import importlib.util +spec = importlib.util.spec_from_file_location( + "vulnerability_scanner", + os.path.join(os.path.dirname(__file__) or ".", "..", "scripts", "vulnerability_scanner.py")) +mod = importlib.util.module_from_spec(spec) +spec.loader.exec_module(mod) + +parse_requirements_file = mod.parse_requirements_file +query_osv = mod.query_osv +parse_osv_vuln = mod.parse_osv_vuln +filter_by_severity = mod.filter_by_severity +Vulnerability = mod.Vulnerability + + +# --- Test Data --- + +SAMPLE_OSV_RESPONSE = [ + { + "id": "GHSA-xxxx-xxxx-xxxx", + "summary": " Arbitrary code execution in django", + "severity": [{"type": "CVSS_V3", "score": {"baseScore": 9.8, "baseSeverity": "CRITICAL"}}], + "affected": [{ + "ranges": [{ + "events": [ + {"introduced": "0"}, + {"fixed": "3.2.14"} + ] + }] + }] + }, + { + "id": "PYSEC-2024-1234", + 
        "summary": " Denial of service in cryptography",
        "severity": [{"type": "CVSS_V3", "score": {"baseScore": 5.3, "baseSeverity": "MEDIUM"}}],
        "affected": [{
            "ranges": [{
                "events": [
                    {"introduced": "0"},
                    {"fixed": "42.0.0"}
                ]
            }]
        }]
    }
]


# --- Tests ---


def test_parse_requirements_simple():
    """Should parse a simple requirements file."""
    # delete=False so the file survives the context exit and can be re-opened
    # by parse_requirements_file on platforms that forbid double-open.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
        f.write("django==4.2.0\n")
        f.write("requests>=2.28.0\n")
        f.write("click~=8.0\n")
        f.flush()
        pkgs = parse_requirements_file(f.name)
        os.unlink(f.name)

    assert "django" in pkgs
    assert pkgs["django"] == "==4.2.0"
    assert "requests" in pkgs
    assert pkgs["requests"] == ">=2.28.0"
    assert "click" in pkgs
    print("PASS: test_parse_requirements_simple")


def test_parse_requirements_extras_and_comments():
    """Should skip comments, blank lines, and handle package extras."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
        f.write("# This is a comment\n")
        f.write("django[argon2]==4.2.0\n")
        f.write("\n")
        f.write("  requests >=2.28.0  # inline comment\n")
        f.flush()
        pkgs = parse_requirements_file(f.name)
        os.unlink(f.name)

    # Extras must be stripped from the key: django[argon2] -> django
    assert "django" in pkgs
    assert pkgs["django"] == "==4.2.0"
    assert "requests" in pkgs
    # Version should capture the comparison
    assert ">=" in pkgs["requests"]
    print("PASS: test_parse_requirements_extras_and_comments")


def test_parse_requirements_include_recursive():
    """Should follow -r includes up to depth 3."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Main requirements.txt -> base.txt -> deep.txt (depth 2, within limit)
        main = os.path.join(tmpdir, "requirements.txt")
        with open(main, 'w') as f:
            f.write("django==4.2.0\n")
            f.write("-r base.txt\n")

        # base.txt
        base = os.path.join(tmpdir, "base.txt")
        with open(base, 'w') as f:
            f.write("requests>=2.28.0\n")
            f.write("-r deep.txt\n")

        # deep.txt
        deep = os.path.join(tmpdir, "deep.txt")
        with open(deep, 'w') as f:
            f.write("click~=8.0\n")

        pkgs = parse_requirements_file(main)

        assert "django" in pkgs
        assert "requests" in pkgs
        assert "click" in pkgs
        print("PASS: test_parse_requirements_include_recursive")


def test_parse_requirements_skip_editable():
    """Should skip -e editable installs and other flags."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
        f.write("-e git+https://github.com/user/repo.git@branch#egg=package\n")
        f.write("--index-url https://pypi.org/simple\n")
        f.write("django==4.2.0\n")
        f.flush()
        pkgs = parse_requirements_file(f.name)
        os.unlink(f.name)

    assert "django" in pkgs
    assert "package" not in pkgs  # should not pick up editable name
    print("PASS: test_parse_requirements_skip_editable")


def test_parse_requirements_nonexistent():
    """Should exit with error on missing file."""
    # sys.exit is mocked, so the function falls through and returns normally;
    # we only assert that exit(1) was requested.
    with patch('sys.exit') as mock_exit:
        pkgs = parse_requirements_file("/nonexistent/requirements.txt")
        mock_exit.assert_called_once_with(1)
    print("PASS: test_parse_requirements_nonexistent")


def test_filter_by_severity():
    """Should filter vulnerabilities by severity threshold."""
    # Positional args follow the Vulnerability dataclass field order:
    # package, version, vuln_id, severity, cvss_score, summary, details_url, fixed_versions
    vulns = [
        Vulnerability("pkg1", "==1.0", "V1", "critical", 9.8, "summary", "url", []),
        Vulnerability("pkg2", "==2.0", "V2", "high", 7.5, "summary", "url", []),
        Vulnerability("pkg3", "==3.0", "V3", "medium", 5.0, "summary", "url", []),
        Vulnerability("pkg4", "==4.0", "V4", "low", 2.0, "summary", "url", []),
    ]

    # min_severity: low includes all
    filtered = filter_by_severity(vulns, "low")
    assert len(filtered) == 4

    # min_severity: medium excludes low
    filtered = filter_by_severity(vulns, "medium")
    assert len(filtered) == 3
    assert all(v.severity in ("critical", "high", "medium") for v in filtered)

    # min_severity: high excludes medium + low
    filtered = filter_by_severity(vulns, "high")
    assert len(filtered) == 2

    # min_severity: critical only
    filtered = filter_by_severity(vulns, "critical")
    assert len(filtered) == 1

    print("PASS: test_filter_by_severity")


def test_parse_osv_vuln():
    """Should parse OSV API response correctly."""
    parsed = parse_osv_vuln(SAMPLE_OSV_RESPONSE, "django", "==4.2.0")

    assert len(parsed) == 2
    assert parsed[0].package == "django"
    assert parsed[0].vuln_id == "GHSA-xxxx-xxxx-xxxx"
    assert parsed[0].severity == "critical"
    assert parsed[0].cvss_score == 9.8
    assert parsed[1].severity == "medium"
    assert parsed[1].cvss_score == 5.3
    print("PASS: test_parse_osv_vuln")


def test_parse_osv_vuln_empty():
    """Should handle empty OSV response."""
    parsed = parse_osv_vuln([], "django", "==4.2.0")
    assert parsed == []
    print("PASS: test_parse_osv_vuln_empty")


def test_query_osv_network_success():
    """Should successfully query OSV API for a real known vulnerable package."""
    # Query for an old django version that likely has known CVEs
    # This test actually hits the network — tagged as integration
    vulns = query_osv("django", "==3.2.0")
    # We don't assert specific results since vulns change over time
    # But we assert the function returns a list and doesn't error
    assert isinstance(vulns, list)
    print("PASS: test_query_osv_network_success")


def test_query_osv_404_no_vulns():
    """OSV returns empty list for packages with no vulns (404-like)."""
    # NOTE(review): despite the name, this mocks a SUCCESSFUL response with an
    # empty "vulns" list — the real 404 branch of query_osv is not exercised.
    with patch('urllib.request.urlopen') as mock_urlopen:
        mock_response = MagicMock()
        mock_response.read.return_value = b'{"vulns": []}'
        # Make the mock usable as a context manager (query_osv uses `with`)
        mock_response.__enter__ = lambda self: self
        mock_response.__exit__ = lambda self, *args: None
        mock_urlopen.return_value = mock_response

        result = query_osv("nonexistent-package-xyz123", "==1.0.0")
        assert result == []
    print("PASS: test_query_osv_404_no_vulns")


if __name__ == '__main__':
    # Run all tests sequentially; each prints PASS or raises AssertionError.
    test_parse_requirements_simple()
    test_parse_requirements_extras_and_comments()
    test_parse_requirements_include_recursive()
    test_parse_requirements_skip_editable()
    test_parse_requirements_nonexistent()
    test_filter_by_severity()
    test_parse_osv_vuln()
    test_parse_osv_vuln_empty()
    test_query_osv_network_success()
    test_query_osv_404_no_vulns()
    print("\nAll tests passed.")