Files
timmy-config/wizards/allegro-primus/knowledge/knowledge_graph.py

520 lines
17 KiB
Python
Raw Normal View History

2026-03-31 20:02:01 +00:00
"""
AP Knowledge Base - Knowledge Graph
Relationship tracking and memory lineage for the Allegro-Primus knowledge base.
"""
import json
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Set, Tuple
from collections import defaultdict
import networkx as nx # type: ignore
from memory_types import MemoryEntry, MemoryRelationship, RelationshipType
class KnowledgeGraph:
    """
    Knowledge graph for tracking relationships between memories.

    Relationships are held in memory as ``MemoryRelationship`` objects grouped
    by their source memory ID and persisted to ``relationships.json`` under
    ``base_dir``; derivation bookkeeping is persisted to ``lineage.json``.
    A NetworkX ``DiGraph`` is built lazily from the relationships and thrown
    away whenever a relationship is added or removed.

    Provides:
    - Link related memories
    - Track memory lineage (derived_from, supersedes)
    - Query by relationship type
    - Find memory clusters
    - Detect conflicts

    Note: annotations that name networkx / memory_types types are written as
    strings so they are not evaluated when the class is defined.
    """

    def __init__(self, base_dir: Path):
        """
        Create a graph rooted at ``base_dir`` and load any persisted state.

        Args:
            base_dir: Directory holding ``relationships.json`` and
                ``lineage.json`` (also used as the default output directory
                for exports and visualizations).
        """
        self.base_dir = Path(base_dir)
        self.relationships_path = self.base_dir / "relationships.json"
        self.lineage_path = self.base_dir / "lineage.json"
        # Lazily-built NetworkX view of the relationships; None means "stale".
        self._graph: Optional["nx.DiGraph"] = None
        # source memory ID -> relationships that start at that memory
        self._relationships: Dict[str, List["MemoryRelationship"]] = defaultdict(list)
        # memory ID -> {'parent', 'generation', 'created_at', 'notes'}
        self._lineage: Dict[str, Dict] = {}
        self._load()

    def _load(self):
        """
        Load relationships and lineage from disk (best effort).

        Failures are reported with a printed warning and never raised. Each
        relationship record is parsed independently so that one malformed
        entry does not discard the records that follow it (which would then
        be lost for good on the next ``_save``).
        """
        # Load relationships
        if self.relationships_path.exists():
            try:
                # Read as UTF-8 to match the encoding used by _save().
                data = json.loads(self.relationships_path.read_text(encoding='utf-8'))
            except Exception as e:
                print(f"Warning: Failed to load relationships: {e}")
                data = []
            if not isinstance(data, list):
                print("Warning: Failed to load relationships: expected a JSON list")
                data = []
            for rel_data in data:
                try:
                    rel = MemoryRelationship(
                        source_id=rel_data['source_id'],
                        target_id=rel_data['target_id'],
                        type=RelationshipType(rel_data['type']),
                        description=rel_data.get('description', ''),
                        # Records without a timestamp are stamped with "now".
                        created_at=datetime.fromisoformat(
                            rel_data.get('created_at', datetime.utcnow().isoformat())
                        ),
                        strength=rel_data.get('strength', 1.0),
                        metadata=rel_data.get('metadata', {})
                    )
                except Exception as e:
                    print(f"Warning: Failed to load relationships: {e}")
                    continue
                self._relationships[rel.source_id].append(rel)

        # Load lineage
        if self.lineage_path.exists():
            try:
                lineage = json.loads(self.lineage_path.read_text(encoding='utf-8'))
            except Exception as e:
                print(f"Warning: Failed to load lineage: {e}")
            else:
                # Anything other than a JSON object would break later
                # `in` / `.get` lookups, so keep the empty default instead.
                if isinstance(lineage, dict):
                    self._lineage = lineage
                else:
                    print("Warning: Failed to load lineage: expected a JSON object")

    def _save(self):
        """
        Save relationships and lineage to disk.

        Creates ``base_dir`` (and parents) when it does not exist yet, then
        rewrites both JSON files in full.
        """
        self.base_dir.mkdir(parents=True, exist_ok=True)

        # Save relationships
        data = [
            {
                'source_id': rel.source_id,
                'target_id': rel.target_id,
                'type': rel.type.value,
                'description': rel.description,
                'created_at': rel.created_at.isoformat(),
                'strength': rel.strength,
                'metadata': rel.metadata
            }
            for rels in self._relationships.values()
            for rel in rels
        ]
        self.relationships_path.write_text(json.dumps(data, indent=2), encoding='utf-8')

        # Save lineage
        self.lineage_path.write_text(json.dumps(self._lineage, indent=2), encoding='utf-8')

    def _ensure_graph(self) -> "nx.DiGraph":
        """
        Return the NetworkX graph, building it from the relationships if stale.

        Each relationship becomes one directed edge source -> target carrying
        ``type``, ``strength`` and ``description`` attributes. Because a
        ``DiGraph`` holds at most one edge per ordered pair, several
        relationships between the same two memories collapse into a single
        edge whose attributes come from the last one added.
        """
        if self._graph is None:
            self._graph = nx.DiGraph()
            # Add all relationships as edges
            for rels in self._relationships.values():
                for rel in rels:
                    self._graph.add_edge(
                        rel.source_id,
                        rel.target_id,
                        type=rel.type.value,
                        strength=rel.strength,
                        description=rel.description
                    )
        return self._graph

    def add_relationship(
        self,
        source_id: str,
        target_id: str,
        rel_type: "RelationshipType",
        description: str = "",
        strength: float = 1.0,
        metadata: Optional[Dict] = None
    ) -> "MemoryRelationship":
        """
        Add a relationship between two memories and persist it.

        Args:
            source_id: ID of the source memory
            target_id: ID of the target memory
            rel_type: Type of relationship
            description: Human-readable description
            strength: Relationship strength (0.0-1.0); not validated here
            metadata: Additional metadata

        Returns:
            The created relationship
        """
        rel = MemoryRelationship(
            source_id=source_id,
            target_id=target_id,
            type=rel_type,
            description=description,
            strength=strength,
            metadata=metadata or {}
        )
        self._relationships[source_id].append(rel)
        # Invalidate graph cache
        self._graph = None
        # Save to disk
        self._save()
        return rel

    def remove_relationship(
        self,
        source_id: str,
        target_id: str,
        rel_type: Optional["RelationshipType"] = None
    ) -> bool:
        """
        Remove relationships going from ``source_id`` to ``target_id``.

        Args:
            source_id: ID of the source memory
            target_id: ID of the target memory
            rel_type: When given, only relationships of this type are
                removed; when None, every relationship between the pair is.

        Returns:
            True if a relationship was removed
        """
        if source_id not in self._relationships:
            return False

        existing = self._relationships[source_id]
        # Compare against None explicitly so a falsy enum member still
        # restricts the removal to that type.
        if rel_type is not None:
            kept = [
                r for r in existing
                if not (r.target_id == target_id and r.type == rel_type)
            ]
        else:
            kept = [r for r in existing if r.target_id != target_id]
        self._relationships[source_id] = kept

        removed = len(kept) < len(existing)
        if removed:
            self._graph = None
            self._save()
        return removed

    def get_relationships(
        self,
        memory_id: str,
        rel_type: Optional["RelationshipType"] = None,
        direction: str = "outgoing"  # "outgoing", "incoming", "both"
    ) -> List["MemoryRelationship"]:
        """
        Get relationships for a memory.

        Incoming relationships are returned as *reversed views*: new
        ``MemoryRelationship`` objects whose ``source_id`` is ``memory_id``
        and whose ``target_id`` is the memory the stored relationship
        actually starts from. Self-relationships are only reported once, as
        outgoing.

        Args:
            memory_id: The memory ID
            rel_type: Filter by relationship type
            direction: Which direction to look (outgoing, incoming, both);
                any other value yields an empty list

        Returns:
            List of matching relationships
        """
        results = []

        if direction in ("outgoing", "both"):
            for rel in self._relationships.get(memory_id, []):
                if rel_type is None or rel.type == rel_type:
                    results.append(rel)

        if direction in ("incoming", "both"):
            for source_id, rels in self._relationships.items():
                if source_id == memory_id:
                    continue
                for rel in rels:
                    if rel.target_id != memory_id:
                        continue
                    if rel_type is None or rel.type == rel_type:
                        # Create a reversed view
                        results.append(MemoryRelationship(
                            source_id=rel.target_id,
                            target_id=rel.source_id,
                            type=rel.type,
                            description=rel.description,
                            created_at=rel.created_at,
                            strength=rel.strength,
                            metadata=rel.metadata
                        ))

        return results

    def find_related(
        self,
        memory_id: str,
        max_depth: int = 2,
        min_strength: float = 0.0
    ) -> Dict[int, List[Tuple[str, float]]]:
        """
        Find memories reachable from a given memory, level by level.

        The traversal follows edges in their stored direction only
        (``DiGraph.neighbors`` yields successors). Each memory is reported
        once, at the first depth where it is reached.

        Args:
            memory_id: Starting memory ID
            max_depth: Maximum traversal depth
            min_strength: Minimum strength an individual edge must have to
                be followed

        Returns:
            Dict mapping depth (int, 0 = the start memory) to a list of
            (memory_id, cumulative_strength) tuples, where the cumulative
            strength is the product of edge strengths along the path taken.
            Empty dict when the memory has no relationships. The last depth
            present may map to an empty list if the traversal ran dry.
        """
        graph = self._ensure_graph()
        if memory_id not in graph:
            return {}

        results: Dict[int, List[Tuple[str, float]]] = {0: [(memory_id, 1.0)]}
        visited = {memory_id}

        for depth in range(1, max_depth + 1):
            results[depth] = []
            for prev_id, prev_strength in results[depth - 1]:
                for neighbor in graph.neighbors(prev_id):
                    if neighbor in visited:
                        continue
                    edge_data = graph.get_edge_data(prev_id, neighbor)
                    if edge_data and edge_data.get('strength', 1.0) >= min_strength:
                        strength = prev_strength * edge_data.get('strength', 1.0)
                        results[depth].append((neighbor, strength))
                        visited.add(neighbor)
            if not results[depth]:
                break

        return results

    def get_lineage(self, memory_id: str) -> Dict:
        """
        Get the full lineage of a memory.

        Ancestors are found by repeatedly following the first outgoing
        DERIVED_FROM relationship; descendants are every memory that is
        (transitively) derived from this one. Both walks stop at memories
        already seen, so derivation cycles terminate and the memory never
        appears in its own lineage.

        Returns:
            Dict with 'memory_id', 'ancestors' (nearest first),
            'descendants' (nearest generation first) and 'superseded_by'
            (ID of a memory that supersedes this one, or None)
        """
        # Walk up the derivation chain
        ancestors: List[str] = []
        seen = {memory_id}
        current = memory_id
        while True:
            rels = self.get_relationships(current, RelationshipType.DERIVED_FROM, "outgoing")
            if not rels:
                break
            parent = rels[0].target_id
            if parent in seen:
                break
            seen.add(parent)
            ancestors.append(parent)
            current = parent

        # Walk down the derivation chain, one generation at a time.
        descendants: List[str] = []
        seen = {memory_id}
        frontier = [memory_id]
        while frontier:
            next_frontier = []
            for node in frontier:
                incoming = self.get_relationships(node, RelationshipType.DERIVED_FROM, "incoming")
                for rel in incoming:
                    # Incoming relationships are reversed views, so the
                    # child (the original source) is in target_id.
                    child = rel.target_id
                    if child not in seen:
                        seen.add(child)
                        descendants.append(child)
                        next_frontier.append(child)
            frontier = next_frontier

        # Check if superseded (again a reversed view: the superseding
        # memory is the original source, exposed here as target_id).
        superseded_by = None
        sup_rel = self.get_relationships(memory_id, RelationshipType.SUPERSEDES, "incoming")
        if sup_rel:
            superseded_by = sup_rel[0].target_id

        return {
            'memory_id': memory_id,
            'ancestors': ancestors,
            'descendants': descendants,
            'superseded_by': superseded_by
        }

    def record_derivation(self, new_memory_id: str, parent_memory_id: str, notes: str = ""):
        """
        Record that a memory was derived from another.

        Adds a DERIVED_FROM relationship new -> parent and a lineage entry
        whose generation is the parent's generation + 1 (or 1 when the
        parent has no lineage entry).
        """
        if parent_memory_id in self._lineage:
            generation = self._lineage[parent_memory_id].get('generation', 0) + 1
        else:
            generation = 1

        # Update lineage tracking first: add_relationship() saves both files,
        # so doing it in this order persists everything with a single write.
        self._lineage[new_memory_id] = {
            'parent': parent_memory_id,
            'generation': generation,
            'created_at': datetime.utcnow().isoformat(),
            'notes': notes
        }

        self.add_relationship(
            new_memory_id,
            parent_memory_id,
            RelationshipType.DERIVED_FROM,
            description=notes or f"Derived from {parent_memory_id}"
        )

    def record_supersession(self, new_memory_id: str, old_memory_id: str, reason: str = ""):
        """Record that ``new_memory_id`` supersedes ``old_memory_id``."""
        self.add_relationship(
            new_memory_id,
            old_memory_id,
            RelationshipType.SUPERSEDES,
            description=reason or f"Supersedes {old_memory_id}"
        )

    def find_conflicts(self, memory_id: str) -> List["MemoryRelationship"]:
        """
        Find memories that conflict with the given memory.

        Looks in both directions, so a conflict is found regardless of which
        of the two memories it was recorded on.
        """
        return self.get_relationships(memory_id, RelationshipType.CONFLICTS_WITH, "both")

    def record_conflict(self, memory_a_id: str, memory_b_id: str, description: str = ""):
        """Record a conflict between two memories."""
        # Only a single a -> b relationship is stored; find_conflicts()
        # queries both directions, which makes it effectively symmetric.
        self.add_relationship(
            memory_a_id,
            memory_b_id,
            RelationshipType.CONFLICTS_WITH,
            description=description
        )

    def find_clusters(self, min_cluster_size: int = 3) -> List[Set[str]]:
        """
        Find clusters of closely related memories.

        Uses greedy modularity community detection on the undirected version
        of the relationship graph.

        Args:
            min_cluster_size: Smallest community to report

        Returns:
            List of memory ID sets (clusters); empty when the graph is
            smaller than ``min_cluster_size`` or detection fails
        """
        graph = self._ensure_graph()
        if len(graph) < min_cluster_size:
            return []

        try:
            # Use greedy modularity communities
            communities = nx.community.greedy_modularity_communities(graph.to_undirected())
            return [set(c) for c in communities if len(c) >= min_cluster_size]
        except Exception as e:
            print(f"Warning: Cluster detection failed: {e}")
            return []

    def get_central_memories(self, n: int = 5) -> List[Tuple[str, float]]:
        """
        Get the most central memories by PageRank.

        Args:
            n: Maximum number of memories to return

        Returns:
            List of (memory_id, centrality_score) tuples, highest score
            first; empty when the graph is empty or the calculation fails
        """
        graph = self._ensure_graph()
        if len(graph) == 0:
            return []

        try:
            centrality = nx.pagerank(graph)
            sorted_nodes = sorted(centrality.items(), key=lambda x: x[1], reverse=True)
            return sorted_nodes[:n]
        except Exception as e:
            print(f"Warning: Centrality calculation failed: {e}")
            return []

    def get_stats(self) -> Dict:
        """
        Get statistics about the knowledge graph.

        Returns:
            Dict with 'total_nodes', 'total_edges', 'by_type' (relationship
            count per type value), 'avg_clustering' and 'is_connected'
            (weak connectivity). The last two keep their defaults (0.0 /
            False) when the graph is empty or the computation fails.
        """
        graph = self._ensure_graph()

        stats = {
            'total_nodes': graph.number_of_nodes(),
            'total_edges': graph.number_of_edges(),
            'by_type': {},
            'avg_clustering': 0.0,
            'is_connected': False,
        }

        # Count by relationship type (counts stored relationships, which can
        # exceed total_edges when a pair has several relationships).
        for rels in self._relationships.values():
            for rel in rels:
                t = rel.type.value
                stats['by_type'][t] = stats['by_type'].get(t, 0) + 1

        if len(graph) > 0:
            # Both metrics are best effort: report the failure, keep defaults.
            try:
                stats['avg_clustering'] = nx.average_clustering(graph.to_undirected())
            except Exception as e:
                print(f"Warning: Clustering calculation failed: {e}")
            try:
                stats['is_connected'] = nx.is_weakly_connected(graph)
            except Exception as e:
                print(f"Warning: Connectivity check failed: {e}")

        return stats

    def export_graph(self, format: str = "gexf") -> str:
        """
        Export the knowledge graph to a timestamped file in ``base_dir``.

        Args:
            format: Export format (gexf, graphml, adjacency)

        Returns:
            Path to exported file

        Raises:
            ValueError: If ``format`` is not one of the supported formats
        """
        graph = self._ensure_graph()
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")

        if format == "gexf":
            path = self.base_dir / f"knowledge_graph_{timestamp}.gexf"
            nx.write_gexf(graph, str(path))
        elif format == "graphml":
            path = self.base_dir / f"knowledge_graph_{timestamp}.graphml"
            nx.write_graphml(graph, str(path))
        elif format == "adjacency":
            path = self.base_dir / f"knowledge_graph_{timestamp}.adj"
            nx.write_adjlist(graph, str(path))
        else:
            raise ValueError(f"Unknown format: {format}")

        return str(path)

    def visualize(self, output_path: Optional[Path] = None) -> Path:
        """
        Create a visualization of the knowledge graph.

        Args:
            output_path: Where to write the image; defaults to
                ``base_dir / "knowledge_graph.png"``

        Returns:
            Path to the generated visualization

        Raises:
            ImportError: If matplotlib is not installed
            ValueError: If the graph has no nodes
        """
        try:
            import matplotlib.pyplot as plt  # type: ignore
        except ImportError as err:
            raise ImportError("matplotlib is required for visualization") from err

        graph = self._ensure_graph()
        if len(graph) == 0:
            raise ValueError("Cannot visualize empty graph")

        fig = plt.figure(figsize=(12, 10))
        try:
            # Layout
            pos = nx.spring_layout(graph, k=2, iterations=50)

            # Draw nodes. All nodes share one color for now; coloring by
            # memory type would need a lookup of each memory's type.
            node_colors = ['skyblue' for _ in graph.nodes()]
            nx.draw_networkx_nodes(graph, pos, node_color=node_colors, node_size=500, alpha=0.9)
            nx.draw_networkx_labels(graph, pos, font_size=8)

            # Draw edges with colors by type. Edges whose type is not listed
            # here are not drawn.
            edge_colors = {
                'related': 'gray',
                'depends_on': 'red',
                'supersedes': 'orange',
                'derived_from': 'green',
                'conflicts_with': 'purple'
            }
            for rel_type, color in edge_colors.items():
                edges = [(u, v) for u, v, d in graph.edges(data=True) if d.get('type') == rel_type]
                if edges:
                    nx.draw_networkx_edges(graph, pos, edgelist=edges, edge_color=color,
                                           arrows=True, arrowsize=10, alpha=0.6)

            plt.title("AP Knowledge Graph")
            plt.axis('off')

            if output_path is None:
                output_path = self.base_dir / "knowledge_graph.png"
            plt.savefig(output_path, bbox_inches='tight', dpi=150)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)

        return output_path