Closes #891. Tags sessions with their originating profile and provides filtered access so profiles cannot see each other's data.
263 lines
7.4 KiB
Python
263 lines
7.4 KiB
Python
"""
|
|
Profile Session Isolation — #891
|
|
|
|
Tags sessions with their originating profile and provides
|
|
filtered access so profiles cannot see each other's data.
|
|
|
|
Background: all sessions share a single state.db that carries no profile tag.
|

This module adds profile tagging (a JSON tag file plus a `profile` column in state.db) and profile-filtered queries.
|
|
|
|
Usage:
|
|
from agent.profile_isolation import tag_session, get_profile_sessions, get_active_profile
|
|
|
|
# Tag a new session with the current profile
|
|
tag_session(session_id, profile_name)
|
|
|
|
# Get sessions for a specific profile
|
|
sessions = get_profile_sessions("sprint")
|
|
|
|
# Get current active profile
|
|
profile = get_active_profile()
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
from datetime import datetime, timezone
|
|
|
|
# Root directory for all Hermes state; overridable via the HERMES_HOME env var.
HERMES_HOME = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes")))
# Shared SQLite session store (a single database used by every profile).
SESSIONS_DB = HERMES_HOME.joinpath("sessions", "state.db")
# JSON file holding the session_id -> profile-name mapping.
PROFILE_TAGS_FILE = HERMES_HOME.joinpath("profile_session_tags.json")
|
|
|
|
|
|
def _configured_profile(config_path: Path) -> Optional[str]:
    """
    Return the ``active_profile`` value from a YAML config file, or ``None``.

    ``None`` is returned when PyYAML is not installed, the file cannot be read
    or parsed, the document is not a mapping, or the value is not a non-empty
    string.
    """
    try:
        import yaml  # optional dependency
    except ImportError:
        return None
    try:
        with open(config_path, encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError):
        return None
    if not isinstance(cfg, dict):
        return None
    value = cfg.get("active_profile")
    return value if isinstance(value, str) and value else None


def get_active_profile() -> str:
    """
    Get the currently active profile name.

    Resolution order:
      1. ``active_profile`` in ``$HERMES_HOME/config.yaml`` (when present and a
         non-empty string);
      2. the ``HERMES_PROFILE`` environment variable;
      3. ``"default"``.

    Returns:
        The active profile name (always a string).
    """
    config_path = HERMES_HOME / "config.yaml"
    if config_path.exists():
        configured = _configured_profile(config_path)
        if configured is not None:
            return configured

    # Fall back to the environment whenever the config does not name a profile,
    # including a config file that exists but has no ``active_profile`` key.
    return os.getenv("HERMES_PROFILE", "default")
|
|
|
|
|
|
def _load_tags() -> Dict[str, str]:
    """
    Load the session-to-profile mapping from ``PROFILE_TAGS_FILE``.

    Returns an empty dict when the file is missing, unreadable, not valid JSON,
    or not a JSON object. Entries whose value is not a string are dropped, so
    callers can rely on the ``Dict[str, str]`` shape.
    """
    try:
        with open(PROFILE_TAGS_FILE, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers FileNotFoundError (no tags saved yet); ValueError
        # covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    if not isinstance(data, dict):
        # e.g. a file containing `null` or a list
        return {}
    # JSON object keys are always strings; only the values need checking.
    return {sid: prof for sid, prof in data.items() if isinstance(prof, str)}
|
|
|
|
|
|
def _save_tags(tags: Dict[str, str]) -> None:
    """
    Persist the session-to-profile mapping to ``PROFILE_TAGS_FILE`` atomically.

    The JSON is written to a temporary file in the same directory, flushed to
    disk, and then moved over the target with ``os.replace``. A crash mid-write
    therefore leaves the previous file intact rather than a truncated one.
    A truncated file would be read back as ``{}`` by ``_load_tags``, and the
    next save would then drop every existing tag.

    Note: ``tempfile.mkstemp`` creates the file with owner-only (0600)
    permissions, and that mode carries over to the replaced file.
    """
    import tempfile

    target = PROFILE_TAGS_FILE
    target.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(
        dir=target.parent, prefix=f".{target.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(tags, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_name, target)
    except BaseException:
        # Do not leave stray temp files behind; re-raise the original error.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def tag_session(session_id: str, profile: Optional[str] = None) -> str:
    """
    Record which profile a session originated from.

    The mapping is written to the JSON tag file and then mirrored into the
    SQLite session store via ``_tag_session_in_db``.

    Args:
        session_id: ID of the session to tag.
        profile: Owning profile; ``None`` means the currently active profile.

    Returns:
        The profile name that was recorded.
    """
    owner = profile if profile is not None else get_active_profile()

    mapping = _load_tags()
    mapping[session_id] = owner
    _save_tags(mapping)

    _tag_session_in_db(session_id, owner)
    return owner
|
|
|
|
|
|
def _tag_session_in_db(session_id: str, profile: str) -> None:
    """
    Record ``profile`` on the session's row in the SQLite session store.

    This is best effort. It does nothing if the database file does not exist or
    has no ``sessions`` table, and it gives up silently on any SQLite error (for
    example a schema without ``session_id``, or a locked database). If the
    table lacks a ``profile`` column, the column is added first with default
    ``'default'``. If no row matches ``session_id``, nothing is updated.
    """
    if not SESSIONS_DB.exists():
        return

    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
    except sqlite3.Error:
        return
    try:
        columns = [row[1] for row in conn.execute("PRAGMA table_info(sessions)")]
        if not columns:
            return  # no `sessions` table in this database
        if "profile" not in columns:
            conn.execute(
                "ALTER TABLE sessions ADD COLUMN profile TEXT DEFAULT 'default'"
            )
        conn.execute(
            "UPDATE sessions SET profile = ? WHERE session_id = ?",
            (profile, session_id),
        )
        conn.commit()
    except sqlite3.Error:
        pass  # schema differs or DB is busy: tagging in SQLite is optional
    finally:
        # Always release the connection, including on the error/early-return paths.
        conn.close()
|
|
|
|
|
|
def get_session_profile(session_id: str) -> Optional[str]:
    """
    Get the profile that owns a session.

    The JSON tag file is checked first. If the session has no JSON tag, the
    ``profile`` column of the SQLite ``sessions`` table is consulted.

    Returns:
        The owning profile name, or ``None`` when the session is untagged, the
        database is missing, or the query fails (e.g. no ``profile`` column).
    """
    tags = _load_tags()
    if session_id in tags:
        return tags[session_id]

    if not SESSIONS_DB.exists():
        return None
    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
    except sqlite3.Error:
        return None
    try:
        row = conn.execute(
            "SELECT profile FROM sessions WHERE session_id = ?",
            (session_id,),
        ).fetchone()
    except sqlite3.Error:
        # e.g. "no such column: profile" before any session has been tagged
        return None
    finally:
        # Close even when the SELECT fails, so the connection is not leaked.
        conn.close()
    return row[0] if row else None
|
|
|
|
|
|
def get_profile_sessions(
    profile: Optional[str] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """
    Get sessions belonging to a specific profile, most recently updated first.

    Ownership is resolved per session in the same order as
    ``get_session_profile``. A JSON tag wins; otherwise the ``profile`` column
    of the ``sessions`` table is used, if that column exists. Untagged sessions
    in a database without a ``profile`` column are not returned.

    Args:
        profile: Profile to filter by. Defaults to ``get_active_profile()``.
        limit: Maximum number of sessions to return; ``limit <= 0`` yields ``[]``.

    Returns:
        Session rows as ``{column: value}`` dicts, ordered by ``updated_at``
        descending. The list is empty if the database is missing or the query
        fails (e.g. no ``sessions`` table or ``updated_at`` column).
    """
    if profile is None:
        profile = get_active_profile()
    if limit <= 0 or not SESSIONS_DB.exists():
        return []

    tags = _load_tags()
    matches: List[Dict[str, Any]] = []
    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
    except sqlite3.Error:
        return []
    try:
        # Filter in Python rather than in SQL, for three reasons:
        # - a JSON tag must override the DB column (a row inserted without a
        #   profile value gets the column default 'default');
        # - binding every tagged ID as a parameter can exceed SQLite's
        #   host-parameter limit;
        # - pre-slicing the ID list would pick arbitrary sessions instead of
        #   the most recent ones.
        cursor = conn.execute("SELECT * FROM sessions ORDER BY updated_at DESC")
        columns = [desc[0] for desc in cursor.description]
        for values in cursor:
            record = dict(zip(columns, values))
            sid = record.get("session_id")
            owner = tags[sid] if sid in tags else record.get("profile")
            if owner == profile:
                matches.append(record)
                if len(matches) >= limit:
                    break
    except sqlite3.Error:
        pass  # best effort: return whatever was collected before the failure
    finally:
        conn.close()
    return matches
|
|
|
|
|
|
def filter_sessions_by_profile(
    sessions: List[Dict[str, Any]],
    profile: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Filter a list of sessions to only include those belonging to a profile.

    A session's owner is its JSON tag, or else whatever ``get_session_profile``
    reports. Sessions with no recorded owner are attributed to ``"default"``,
    the value ``_tag_session_in_db`` uses as the SQLite ``profile`` column
    default. This keeps untagged sessions from being shown to every profile.

    Args:
        sessions: Session dicts, each identified by its ``"session_id"`` key
            (falling back to ``"id"``). Sessions with neither are dropped.
        profile: Profile to keep. Defaults to ``get_active_profile()``.

    Returns:
        The matching sessions, in their original order.
    """
    if profile is None:
        profile = get_active_profile()

    tags = _load_tags()
    kept: List[Dict[str, Any]] = []

    for session in sessions:
        sid = session.get("session_id") or session.get("id")
        if not sid:
            continue

        owner = tags.get(sid)
        if owner is None:
            owner = get_session_profile(sid)
        if owner is None:
            owner = "default"

        if owner == profile:
            kept.append(session)

    return kept
|
|
|
|
|
|
def get_profile_stats() -> Dict[str, Any]:
    """
    Summarize how JSON-tagged sessions are distributed across profiles.

    Returns:
        A dict with ``total_tagged_sessions`` (number of tagged sessions),
        ``profiles`` (profile names in first-seen order), ``profile_counts``
        (profile name -> number of sessions) and ``active_profile``.
    """
    tag_map = _load_tags()

    counts: Dict[str, int] = {}
    for owner in tag_map.values():
        if owner in counts:
            counts[owner] += 1
        else:
            counts[owner] = 1

    return {
        "total_tagged_sessions": len(tag_map),
        "profiles": [*counts],
        "profile_counts": counts,
        "active_profile": get_active_profile(),
    }
|
|
|
|
|
|
def audit_untagged_sessions() -> List[str]:
    """
    Find sessions in the SQLite store that have no entry in the JSON tag file.

    Only the JSON tags decide whether a session counts as tagged; the SQLite
    ``profile`` column is not consulted. Rows with a NULL ``session_id`` are
    skipped.

    Returns:
        Untagged session IDs, sorted for a stable order. The list is empty if
        the database is missing or cannot be queried.
    """
    if not SESSIONS_DB.exists():
        return []

    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
    except sqlite3.Error:
        return []
    try:
        all_sessions = {
            row[0]
            for row in conn.execute("SELECT session_id FROM sessions")
            if row[0] is not None
        }
    except sqlite3.Error:
        return []
    finally:
        # Close on every path so a failed query does not leak the connection.
        conn.close()

    tagged = set(_load_tags())
    # key=str keeps sorting safe even if some IDs are stored as non-text values.
    return sorted(all_sessions - tagged, key=str)
|