"""
|
||
|
|
AP Knowledge Base - Knowledge Graph
|
||
|
|
|
||
|
|
Relationship tracking and memory lineage for the Allegro-Primus knowledge base.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import List, Dict, Optional, Set, Tuple
|
||
|
|
from collections import defaultdict
|
||
|
|
import networkx as nx # type: ignore
|
||
|
|
|
||
|
|
from memory_types import MemoryEntry, MemoryRelationship, RelationshipType
|
||
|
|
|
||
|
|
|
||
|
|
class KnowledgeGraph:
|
||
|
|
"""
|
||
|
|
Knowledge graph for tracking relationships between memories.
|
||
|
|
|
||
|
|
Provides:
|
||
|
|
- Link related memories
|
||
|
|
- Track memory lineage (derived_from, supersedes)
|
||
|
|
- Query by relationship type
|
||
|
|
- Find memory clusters
|
||
|
|
- Detect conflicts
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(self, base_dir: Path):
|
||
|
|
self.base_dir = Path(base_dir)
|
||
|
|
self.relationships_path = self.base_dir / "relationships.json"
|
||
|
|
self.lineage_path = self.base_dir / "lineage.json"
|
||
|
|
|
||
|
|
self._graph: Optional[nx.DiGraph] = None
|
||
|
|
self._relationships: Dict[str, List[MemoryRelationship]] = defaultdict(list)
|
||
|
|
self._lineage: Dict[str, Dict] = {}
|
||
|
|
|
||
|
|
self._load()
|
||
|
|
|
||
|
|
def _load(self):
|
||
|
|
"""Load relationships and lineage from disk."""
|
||
|
|
# Load relationships
|
||
|
|
if self.relationships_path.exists():
|
||
|
|
try:
|
||
|
|
data = json.loads(self.relationships_path.read_text())
|
||
|
|
for rel_data in data:
|
||
|
|
rel = MemoryRelationship(
|
||
|
|
source_id=rel_data['source_id'],
|
||
|
|
target_id=rel_data['target_id'],
|
||
|
|
type=RelationshipType(rel_data['type']),
|
||
|
|
description=rel_data.get('description', ''),
|
||
|
|
created_at=datetime.fromisoformat(rel_data.get('created_at', datetime.utcnow().isoformat())),
|
||
|
|
strength=rel_data.get('strength', 1.0),
|
||
|
|
metadata=rel_data.get('metadata', {})
|
||
|
|
)
|
||
|
|
self._relationships[rel.source_id].append(rel)
|
||
|
|
except Exception as e:
|
||
|
|
print(f"Warning: Failed to load relationships: {e}")
|
||
|
|
|
||
|
|
# Load lineage
|
||
|
|
if self.lineage_path.exists():
|
||
|
|
try:
|
||
|
|
self._lineage = json.loads(self.lineage_path.read_text())
|
||
|
|
except Exception as e:
|
||
|
|
print(f"Warning: Failed to load lineage: {e}")
|
||
|
|
|
||
|
|
def _save(self):
|
||
|
|
"""Save relationships and lineage to disk."""
|
||
|
|
# Save relationships
|
||
|
|
data = []
|
||
|
|
for rels in self._relationships.values():
|
||
|
|
for rel in rels:
|
||
|
|
data.append({
|
||
|
|
'source_id': rel.source_id,
|
||
|
|
'target_id': rel.target_id,
|
||
|
|
'type': rel.type.value,
|
||
|
|
'description': rel.description,
|
||
|
|
'created_at': rel.created_at.isoformat(),
|
||
|
|
'strength': rel.strength,
|
||
|
|
'metadata': rel.metadata
|
||
|
|
})
|
||
|
|
|
||
|
|
self.relationships_path.write_text(json.dumps(data, indent=2), encoding='utf-8')
|
||
|
|
|
||
|
|
# Save lineage
|
||
|
|
self.lineage_path.write_text(json.dumps(self._lineage, indent=2), encoding='utf-8')
|
||
|
|
|
||
|
|
def _ensure_graph(self) -> nx.DiGraph:
|
||
|
|
"""Ensure the NetworkX graph is built."""
|
||
|
|
if self._graph is None:
|
||
|
|
self._graph = nx.DiGraph()
|
||
|
|
|
||
|
|
# Add all relationships as edges
|
||
|
|
for rels in self._relationships.values():
|
||
|
|
for rel in rels:
|
||
|
|
self._graph.add_edge(
|
||
|
|
rel.source_id,
|
||
|
|
rel.target_id,
|
||
|
|
type=rel.type.value,
|
||
|
|
strength=rel.strength,
|
||
|
|
description=rel.description
|
||
|
|
)
|
||
|
|
|
||
|
|
return self._graph
|
||
|
|
|
||
|
|
def add_relationship(
|
||
|
|
self,
|
||
|
|
source_id: str,
|
||
|
|
target_id: str,
|
||
|
|
rel_type: RelationshipType,
|
||
|
|
description: str = "",
|
||
|
|
strength: float = 1.0,
|
||
|
|
metadata: Optional[Dict] = None
|
||
|
|
) -> MemoryRelationship:
|
||
|
|
"""
|
||
|
|
Add a relationship between two memories.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
source_id: ID of the source memory
|
||
|
|
target_id: ID of the target memory
|
||
|
|
rel_type: Type of relationship
|
||
|
|
description: Human-readable description
|
||
|
|
strength: Relationship strength (0.0-1.0)
|
||
|
|
metadata: Additional metadata
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
The created relationship
|
||
|
|
"""
|
||
|
|
rel = MemoryRelationship(
|
||
|
|
source_id=source_id,
|
||
|
|
target_id=target_id,
|
||
|
|
type=rel_type,
|
||
|
|
description=description,
|
||
|
|
strength=strength,
|
||
|
|
metadata=metadata or {}
|
||
|
|
)
|
||
|
|
|
||
|
|
self._relationships[source_id].append(rel)
|
||
|
|
|
||
|
|
# Invalidate graph cache
|
||
|
|
self._graph = None
|
||
|
|
|
||
|
|
# Save to disk
|
||
|
|
self._save()
|
||
|
|
|
||
|
|
return rel
|
||
|
|
|
||
|
|
def remove_relationship(self, source_id: str, target_id: str, rel_type: Optional[RelationshipType] = None) -> bool:
|
||
|
|
"""
|
||
|
|
Remove a relationship between memories.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
True if a relationship was removed
|
||
|
|
"""
|
||
|
|
if source_id not in self._relationships:
|
||
|
|
return False
|
||
|
|
|
||
|
|
original_count = len(self._relationships[source_id])
|
||
|
|
|
||
|
|
if rel_type:
|
||
|
|
self._relationships[source_id] = [
|
||
|
|
r for r in self._relationships[source_id]
|
||
|
|
if not (r.target_id == target_id and r.type == rel_type)
|
||
|
|
]
|
||
|
|
else:
|
||
|
|
self._relationships[source_id] = [
|
||
|
|
r for r in self._relationships[source_id]
|
||
|
|
if r.target_id != target_id
|
||
|
|
]
|
||
|
|
|
||
|
|
removed = len(self._relationships[source_id]) < original_count
|
||
|
|
|
||
|
|
if removed:
|
||
|
|
self._graph = None
|
||
|
|
self._save()
|
||
|
|
|
||
|
|
return removed
|
||
|
|
|
||
|
|
def get_relationships(
|
||
|
|
self,
|
||
|
|
memory_id: str,
|
||
|
|
rel_type: Optional[RelationshipType] = None,
|
||
|
|
direction: str = "outgoing" # "outgoing", "incoming", "both"
|
||
|
|
) -> List[MemoryRelationship]:
|
||
|
|
"""
|
||
|
|
Get relationships for a memory.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
memory_id: The memory ID
|
||
|
|
rel_type: Filter by relationship type
|
||
|
|
direction: Which direction to look (outgoing, incoming, both)
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
List of matching relationships
|
||
|
|
"""
|
||
|
|
results = []
|
||
|
|
|
||
|
|
if direction in ("outgoing", "both"):
|
||
|
|
for rel in self._relationships.get(memory_id, []):
|
||
|
|
if rel_type is None or rel.type == rel_type:
|
||
|
|
results.append(rel)
|
||
|
|
|
||
|
|
if direction in ("incoming", "both"):
|
||
|
|
for source_id, rels in self._relationships.items():
|
||
|
|
if source_id == memory_id:
|
||
|
|
continue
|
||
|
|
for rel in rels:
|
||
|
|
if rel.target_id == memory_id:
|
||
|
|
if rel_type is None or rel.type == rel_type:
|
||
|
|
# Create a reversed view
|
||
|
|
results.append(MemoryRelationship(
|
||
|
|
source_id=rel.target_id,
|
||
|
|
target_id=rel.source_id,
|
||
|
|
type=rel.type,
|
||
|
|
description=rel.description,
|
||
|
|
created_at=rel.created_at,
|
||
|
|
strength=rel.strength,
|
||
|
|
metadata=rel.metadata
|
||
|
|
))
|
||
|
|
|
||
|
|
return results
|
||
|
|
|
||
|
|
def find_related(
|
||
|
|
self,
|
||
|
|
memory_id: str,
|
||
|
|
max_depth: int = 2,
|
||
|
|
min_strength: float = 0.0
|
||
|
|
) -> Dict[str, List[Tuple[str, float]]]:
|
||
|
|
"""
|
||
|
|
Find all memories related to a given memory.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
memory_id: Starting memory ID
|
||
|
|
max_depth: Maximum traversal depth
|
||
|
|
min_strength: Minimum relationship strength
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Dict mapping depth to list of (memory_id, cumulative_strength) tuples
|
||
|
|
"""
|
||
|
|
graph = self._ensure_graph()
|
||
|
|
|
||
|
|
if memory_id not in graph:
|
||
|
|
return {}
|
||
|
|
|
||
|
|
results: Dict[int, List[Tuple[str, float]]] = {0: [(memory_id, 1.0)]}
|
||
|
|
visited = {memory_id}
|
||
|
|
|
||
|
|
for depth in range(1, max_depth + 1):
|
||
|
|
results[depth] = []
|
||
|
|
|
||
|
|
for prev_id, prev_strength in results[depth - 1]:
|
||
|
|
for neighbor in graph.neighbors(prev_id):
|
||
|
|
if neighbor in visited:
|
||
|
|
continue
|
||
|
|
|
||
|
|
edge_data = graph.get_edge_data(prev_id, neighbor)
|
||
|
|
if edge_data and edge_data.get('strength', 1.0) >= min_strength:
|
||
|
|
strength = prev_strength * edge_data.get('strength', 1.0)
|
||
|
|
results[depth].append((neighbor, strength))
|
||
|
|
visited.add(neighbor)
|
||
|
|
|
||
|
|
if not results[depth]:
|
||
|
|
break
|
||
|
|
|
||
|
|
return results
|
||
|
|
|
||
|
|
def get_lineage(self, memory_id: str) -> Dict:
|
||
|
|
"""
|
||
|
|
Get the full lineage of a memory (what it was derived from).
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Dict with 'ancestors', 'descendants', and 'superseded_by' lists
|
||
|
|
"""
|
||
|
|
ancestors = []
|
||
|
|
descendants = []
|
||
|
|
superseded_by = None
|
||
|
|
|
||
|
|
# Walk up the derivation chain
|
||
|
|
current = memory_id
|
||
|
|
visited = set()
|
||
|
|
|
||
|
|
while current and current not in visited:
|
||
|
|
visited.add(current)
|
||
|
|
rels = self.get_relationships(current, RelationshipType.DERIVED_FROM, "outgoing")
|
||
|
|
|
||
|
|
if rels:
|
||
|
|
parent = rels[0].target_id
|
||
|
|
ancestors.append(parent)
|
||
|
|
current = parent
|
||
|
|
else:
|
||
|
|
break
|
||
|
|
|
||
|
|
# Walk down the derivation chain
|
||
|
|
visited.clear()
|
||
|
|
current = memory_id
|
||
|
|
|
||
|
|
while current and current not in visited:
|
||
|
|
visited.add(current)
|
||
|
|
rels = self.get_relationships(current, RelationshipType.DERIVED_FROM, "incoming")
|
||
|
|
|
||
|
|
for rel in rels:
|
||
|
|
descendants.append(rel.source_id)
|
||
|
|
|
||
|
|
# Check if superseded
|
||
|
|
sup_rel = self.get_relationships(memory_id, RelationshipType.SUPERSEDES, "incoming")
|
||
|
|
if sup_rel:
|
||
|
|
superseded_by = sup_rel[0].source_id
|
||
|
|
|
||
|
|
return {
|
||
|
|
'memory_id': memory_id,
|
||
|
|
'ancestors': ancestors,
|
||
|
|
'descendants': descendants,
|
||
|
|
'superseded_by': superseded_by
|
||
|
|
}
|
||
|
|
|
||
|
|
def record_derivation(self, new_memory_id: str, parent_memory_id: str, notes: str = ""):
|
||
|
|
"""Record that a memory was derived from another."""
|
||
|
|
self.add_relationship(
|
||
|
|
new_memory_id,
|
||
|
|
parent_memory_id,
|
||
|
|
RelationshipType.DERIVED_FROM,
|
||
|
|
description=notes or f"Derived from {parent_memory_id}"
|
||
|
|
)
|
||
|
|
|
||
|
|
# Update lineage tracking
|
||
|
|
if parent_memory_id in self._lineage:
|
||
|
|
generation = self._lineage[parent_memory_id].get('generation', 0) + 1
|
||
|
|
else:
|
||
|
|
generation = 1
|
||
|
|
|
||
|
|
self._lineage[new_memory_id] = {
|
||
|
|
'parent': parent_memory_id,
|
||
|
|
'generation': generation,
|
||
|
|
'created_at': datetime.utcnow().isoformat(),
|
||
|
|
'notes': notes
|
||
|
|
}
|
||
|
|
|
||
|
|
self._save()
|
||
|
|
|
||
|
|
def record_supersession(self, new_memory_id: str, old_memory_id: str, reason: str = ""):
|
||
|
|
"""Record that a memory supersedes another."""
|
||
|
|
self.add_relationship(
|
||
|
|
new_memory_id,
|
||
|
|
old_memory_id,
|
||
|
|
RelationshipType.SUPERSEDES,
|
||
|
|
description=reason or f"Supersedes {old_memory_id}"
|
||
|
|
)
|
||
|
|
|
||
|
|
def find_conflicts(self, memory_id: str) -> List[MemoryRelationship]:
|
||
|
|
"""Find memories that conflict with the given memory."""
|
||
|
|
return self.get_relationships(memory_id, RelationshipType.CONFLICTS_WITH, "both")
|
||
|
|
|
||
|
|
def record_conflict(self, memory_a_id: str, memory_b_id: str, description: str = ""):
|
||
|
|
"""Record a conflict between two memories."""
|
||
|
|
# Add bidirectional conflict
|
||
|
|
self.add_relationship(
|
||
|
|
memory_a_id,
|
||
|
|
memory_b_id,
|
||
|
|
RelationshipType.CONFLICTS_WITH,
|
||
|
|
description=description
|
||
|
|
)
|
||
|
|
|
||
|
|
def find_clusters(self, min_cluster_size: int = 3) -> List[Set[str]]:
|
||
|
|
"""
|
||
|
|
Find clusters of closely related memories.
|
||
|
|
|
||
|
|
Uses community detection on the relationship graph.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
List of memory ID sets (clusters)
|
||
|
|
"""
|
||
|
|
graph = self._ensure_graph()
|
||
|
|
|
||
|
|
if len(graph) < min_cluster_size:
|
||
|
|
return []
|
||
|
|
|
||
|
|
try:
|
||
|
|
# Use greedy modularity communities
|
||
|
|
communities = nx.community.greedy_modularity_communities(graph.to_undirected())
|
||
|
|
return [set(c) for c in communities if len(c) >= min_cluster_size]
|
||
|
|
except Exception as e:
|
||
|
|
print(f"Warning: Cluster detection failed: {e}")
|
||
|
|
return []
|
||
|
|
|
||
|
|
def get_central_memories(self, n: int = 5) -> List[Tuple[str, float]]:
|
||
|
|
"""
|
||
|
|
Get the most central memories by PageRank.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
List of (memory_id, centrality_score) tuples
|
||
|
|
"""
|
||
|
|
graph = self._ensure_graph()
|
||
|
|
|
||
|
|
if len(graph) == 0:
|
||
|
|
return []
|
||
|
|
|
||
|
|
try:
|
||
|
|
centrality = nx.pagerank(graph)
|
||
|
|
sorted_nodes = sorted(centrality.items(), key=lambda x: x[1], reverse=True)
|
||
|
|
return sorted_nodes[:n]
|
||
|
|
except Exception as e:
|
||
|
|
print(f"Warning: Centrality calculation failed: {e}")
|
||
|
|
return []
|
||
|
|
|
||
|
|
def get_stats(self) -> Dict:
|
||
|
|
"""Get statistics about the knowledge graph."""
|
||
|
|
graph = self._ensure_graph()
|
||
|
|
|
||
|
|
stats = {
|
||
|
|
'total_nodes': graph.number_of_nodes(),
|
||
|
|
'total_edges': graph.number_of_edges(),
|
||
|
|
'by_type': {},
|
||
|
|
'avg_clustering': 0.0,
|
||
|
|
'is_connected': False,
|
||
|
|
}
|
||
|
|
|
||
|
|
# Count by relationship type
|
||
|
|
for rels in self._relationships.values():
|
||
|
|
for rel in rels:
|
||
|
|
t = rel.type.value
|
||
|
|
stats['by_type'][t] = stats['by_type'].get(t, 0) + 1
|
||
|
|
|
||
|
|
if len(graph) > 0:
|
||
|
|
try:
|
||
|
|
stats['avg_clustering'] = nx.average_clustering(graph.to_undirected())
|
||
|
|
except:
|
||
|
|
pass
|
||
|
|
|
||
|
|
try:
|
||
|
|
stats['is_connected'] = nx.is_weakly_connected(graph)
|
||
|
|
except:
|
||
|
|
pass
|
||
|
|
|
||
|
|
return stats
|
||
|
|
|
||
|
|
def export_graph(self, format: str = "gexf") -> str:
|
||
|
|
"""
|
||
|
|
Export the knowledge graph in various formats.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
format: Export format (gexf, graphml, adjacency)
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Path to exported file
|
||
|
|
"""
|
||
|
|
graph = self._ensure_graph()
|
||
|
|
|
||
|
|
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||
|
|
|
||
|
|
if format == "gexf":
|
||
|
|
path = self.base_dir / f"knowledge_graph_{timestamp}.gexf"
|
||
|
|
nx.write_gexf(graph, str(path))
|
||
|
|
elif format == "graphml":
|
||
|
|
path = self.base_dir / f"knowledge_graph_{timestamp}.graphml"
|
||
|
|
nx.write_graphml(graph, str(path))
|
||
|
|
elif format == "adjacency":
|
||
|
|
path = self.base_dir / f"knowledge_graph_{timestamp}.adj"
|
||
|
|
nx.write_adjlist(graph, str(path))
|
||
|
|
else:
|
||
|
|
raise ValueError(f"Unknown format: {format}")
|
||
|
|
|
||
|
|
return str(path)
|
||
|
|
|
||
|
|
def visualize(self, output_path: Optional[Path] = None) -> Path:
|
||
|
|
"""
|
||
|
|
Create a visualization of the knowledge graph.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Path to the generated visualization
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
import matplotlib.pyplot as plt # type: ignore
|
||
|
|
except ImportError:
|
||
|
|
raise ImportError("matplotlib is required for visualization")
|
||
|
|
|
||
|
|
graph = self._ensure_graph()
|
||
|
|
|
||
|
|
if len(graph) == 0:
|
||
|
|
raise ValueError("Cannot visualize empty graph")
|
||
|
|
|
||
|
|
plt.figure(figsize=(12, 10))
|
||
|
|
|
||
|
|
# Layout
|
||
|
|
pos = nx.spring_layout(graph, k=2, iterations=50)
|
||
|
|
|
||
|
|
# Draw nodes
|
||
|
|
node_colors = []
|
||
|
|
for node in graph.nodes():
|
||
|
|
# Color by type (would need to look up memory type)
|
||
|
|
node_colors.append('skyblue')
|
||
|
|
|
||
|
|
nx.draw_networkx_nodes(graph, pos, node_color=node_colors, node_size=500, alpha=0.9)
|
||
|
|
nx.draw_networkx_labels(graph, pos, font_size=8)
|
||
|
|
|
||
|
|
# Draw edges with colors by type
|
||
|
|
edge_colors = {
|
||
|
|
'related': 'gray',
|
||
|
|
'depends_on': 'red',
|
||
|
|
'supersedes': 'orange',
|
||
|
|
'derived_from': 'green',
|
||
|
|
'conflicts_with': 'purple'
|
||
|
|
}
|
||
|
|
|
||
|
|
for rel_type, color in edge_colors.items():
|
||
|
|
edges = [(u, v) for u, v, d in graph.edges(data=True) if d.get('type') == rel_type]
|
||
|
|
if edges:
|
||
|
|
nx.draw_networkx_edges(graph, pos, edgelist=edges, edge_color=color,
|
||
|
|
arrows=True, arrowsize=10, alpha=0.6)
|
||
|
|
|
||
|
|
plt.title("AP Knowledge Graph")
|
||
|
|
plt.axis('off')
|
||
|
|
|
||
|
|
if output_path is None:
|
||
|
|
output_path = self.base_dir / "knowledge_graph.png"
|
||
|
|
|
||
|
|
plt.savefig(output_path, bbox_inches='tight', dpi=150)
|
||
|
|
plt.close()
|
||
|
|
|
||
|
|
return output_path
|