"""
AP Knowledge Base - Knowledge Graph

Relationship tracking and memory lineage for the Allegro-Primus knowledge base.
"""

import json
from pathlib import Path
from datetime import datetime, timezone
from itertools import chain
from typing import List, Dict, Iterator, Optional, Set, Tuple
from collections import defaultdict, deque

import networkx as nx  # type: ignore

from memory_types import MemoryEntry, MemoryRelationship, RelationshipType


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same kind of value as the deprecated ``datetime.utcnow()``
    (no tzinfo attached), so the ISO strings written by this module keep
    their existing format.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class KnowledgeGraph:
    """
    Knowledge graph for tracking relationships between memories.

    Provides:
    - Link related memories
    - Track memory lineage (derived_from, supersedes)
    - Query by relationship type
    - Find memory clusters
    - Detect conflicts

    Relationships are kept in memory, grouped by source memory ID, and are
    written to ``relationships.json`` / ``lineage.json`` under ``base_dir``
    after every mutation. A NetworkX ``DiGraph`` is built lazily from the
    relationships and discarded whenever they change.
    """

    def __init__(self, base_dir: Path):
        """
        Create a graph backed by files in ``base_dir`` and load existing data.

        Args:
            base_dir: Directory holding ``relationships.json`` and
                ``lineage.json``. It is created on first save if missing.
        """
        self.base_dir = Path(base_dir)
        self.relationships_path = self.base_dir / "relationships.json"
        self.lineage_path = self.base_dir / "lineage.json"

        # Lazily built cache; None means "rebuild on next use".
        self._graph: Optional[nx.DiGraph] = None
        # source memory ID -> relationships that start at that memory.
        self._relationships: Dict[str, List[MemoryRelationship]] = defaultdict(list)
        # memory ID -> {'parent', 'generation', 'created_at', 'notes'}.
        self._lineage: Dict[str, Dict] = {}

        self._load()

    def _all_relationships(self) -> Iterator[MemoryRelationship]:
        """Iterate over every stored relationship, regardless of source."""
        return chain.from_iterable(self._relationships.values())

    def _load(self):
        """
        Load relationships and lineage from disk.

        Loading is best-effort: an unreadable file or a malformed record is
        reported with a printed warning and skipped instead of raising.
        """
        # Load relationships
        if self.relationships_path.exists():
            try:
                # _save writes UTF-8, so read with the same encoding.
                raw = json.loads(self.relationships_path.read_text(encoding='utf-8'))
                if not isinstance(raw, list):
                    raise ValueError("expected a JSON list of relationships")
            except Exception as e:
                print(f"Warning: Failed to load relationships: {e}")
                raw = []

            for rel_data in raw:
                # Handle records one by one so a single bad entry does not
                # discard every entry that follows it.
                try:
                    created_raw = rel_data.get('created_at')
                    rel = MemoryRelationship(
                        source_id=rel_data['source_id'],
                        target_id=rel_data['target_id'],
                        type=RelationshipType(rel_data['type']),
                        description=rel_data.get('description', ''),
                        created_at=(
                            datetime.fromisoformat(created_raw)
                            if created_raw is not None
                            else _utcnow()
                        ),
                        strength=rel_data.get('strength', 1.0),
                        metadata=rel_data.get('metadata', {})
                    )
                except Exception as e:
                    print(f"Warning: Skipping malformed relationship record: {e}")
                    continue
                self._relationships[rel.source_id].append(rel)

        # Load lineage
        if self.lineage_path.exists():
            try:
                lineage = json.loads(self.lineage_path.read_text(encoding='utf-8'))
                if not isinstance(lineage, dict):
                    raise ValueError("expected a JSON object")
                self._lineage = lineage
            except Exception as e:
                print(f"Warning: Failed to load lineage: {e}")

    def _save(self):
        """
        Save relationships and lineage to disk.

        Both payloads are serialised before anything is written, so a value
        that cannot be JSON-encoded does not leave only one file updated.
        """
        data = []
        for rel in self._all_relationships():
            data.append({
                'source_id': rel.source_id,
                'target_id': rel.target_id,
                'type': rel.type.value,
                'description': rel.description,
                'created_at': rel.created_at.isoformat(),
                'strength': rel.strength,
                'metadata': rel.metadata
            })

        relationships_json = json.dumps(data, indent=2)
        lineage_json = json.dumps(self._lineage, indent=2)

        # write_text does not create missing parent directories.
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.relationships_path.write_text(relationships_json, encoding='utf-8')
        self.lineage_path.write_text(lineage_json, encoding='utf-8')

    def _ensure_graph(self) -> nx.DiGraph:
        """
        Ensure the NetworkX graph is built and return it.

        Every relationship becomes a directed edge source -> target carrying
        ``type``, ``strength`` and ``description`` attributes. Because a
        ``DiGraph`` holds one edge per ordered node pair, several
        relationships between the same two memories collapse into a single
        edge whose attributes come from the last one added.
        """
        if self._graph is None:
            graph = nx.DiGraph()
            for rel in self._all_relationships():
                graph.add_edge(
                    rel.source_id,
                    rel.target_id,
                    type=rel.type.value,
                    strength=rel.strength,
                    description=rel.description
                )
            self._graph = graph

        return self._graph

    def add_relationship(
        self,
        source_id: str,
        target_id: str,
        rel_type: RelationshipType,
        description: str = "",
        strength: float = 1.0,
        metadata: Optional[Dict] = None
    ) -> MemoryRelationship:
        """
        Add a relationship between two memories and persist it.

        Args:
            source_id: ID of the source memory
            target_id: ID of the target memory
            rel_type: Type of relationship
            description: Human-readable description
            strength: Relationship strength (0.0-1.0); the value is stored
                as given and is not range-checked here
            metadata: Additional metadata

        Returns:
            The created relationship
        """
        rel = MemoryRelationship(
            source_id=source_id,
            target_id=target_id,
            type=rel_type,
            description=description,
            strength=strength,
            metadata=metadata or {}
        )

        self._relationships[source_id].append(rel)

        # Invalidate graph cache
        self._graph = None

        # Save to disk
        self._save()

        return rel

    def remove_relationship(self, source_id: str, target_id: str,
                            rel_type: Optional[RelationshipType] = None) -> bool:
        """
        Remove relationships from ``source_id`` to ``target_id``.

        Args:
            source_id: ID of the source memory
            target_id: ID of the target memory
            rel_type: If given, only relationships of this type are removed;
                otherwise every relationship between the pair is removed

        Returns:
            True if a relationship was removed
        """
        if source_id not in self._relationships:
            return False

        existing = self._relationships[source_id]

        # Compare against None explicitly (as get_relationships does) so the
        # filter never depends on the truthiness of the enum member.
        if rel_type is not None:
            kept = [
                r for r in existing
                if not (r.target_id == target_id and r.type == rel_type)
            ]
        else:
            kept = [r for r in existing if r.target_id != target_id]

        removed = len(kept) < len(existing)

        if removed:
            if kept:
                self._relationships[source_id] = kept
            else:
                # Do not keep an empty bucket around for this source.
                del self._relationships[source_id]
            self._graph = None
            self._save()

        return removed

    def get_relationships(
        self,
        memory_id: str,
        rel_type: Optional[RelationshipType] = None,
        direction: str = "outgoing"  # "outgoing", "incoming", "both"
    ) -> List[MemoryRelationship]:
        """
        Get relationships for a memory.

        Incoming relationships are returned as *reversed views*: new
        ``MemoryRelationship`` objects whose ``source_id`` is ``memory_id``
        and whose ``target_id`` is the memory the relationship actually
        starts from. In every returned item ``target_id`` is therefore "the
        other memory". Self-referencing relationships are reported only as
        outgoing.

        Args:
            memory_id: The memory ID
            rel_type: Filter by relationship type
            direction: Which direction to look (outgoing, incoming, both);
                any other value yields an empty list

        Returns:
            List of matching relationships
        """
        results = []

        if direction in ("outgoing", "both"):
            for rel in self._relationships.get(memory_id, []):
                if rel_type is None or rel.type == rel_type:
                    results.append(rel)

        if direction in ("incoming", "both"):
            for source_id, rels in self._relationships.items():
                if source_id == memory_id:
                    continue
                for rel in rels:
                    if rel.target_id != memory_id:
                        continue
                    if rel_type is None or rel.type == rel_type:
                        # Create a reversed view
                        results.append(MemoryRelationship(
                            source_id=rel.target_id,
                            target_id=rel.source_id,
                            type=rel.type,
                            description=rel.description,
                            created_at=rel.created_at,
                            strength=rel.strength,
                            metadata=rel.metadata
                        ))

        return results

    def find_related(
        self,
        memory_id: str,
        max_depth: int = 2,
        min_strength: float = 0.0
    ) -> Dict[int, List[Tuple[str, float]]]:
        """
        Find all memories reachable from a given memory.

        Traversal follows edges in their stored direction (source -> target)
        and visits each memory once, at the first depth where it is reached.

        Args:
            memory_id: Starting memory ID
            max_depth: Maximum traversal depth
            min_strength: Minimum strength an edge needs to be followed

        Returns:
            Dict mapping depth (int) to a list of (memory_id,
            cumulative_strength) tuples, where cumulative strength is the
            product of edge strengths along the path used to reach the
            memory. Depth 0 holds the starting memory with strength 1.0.
            Empty if ``memory_id`` is not in the graph.
        """
        graph = self._ensure_graph()

        if memory_id not in graph:
            return {}

        results: Dict[int, List[Tuple[str, float]]] = {0: [(memory_id, 1.0)]}
        visited = {memory_id}

        for depth in range(1, max_depth + 1):
            results[depth] = []

            for prev_id, prev_strength in results[depth - 1]:
                for neighbor in graph.neighbors(prev_id):
                    if neighbor in visited:
                        continue

                    edge_data = graph.get_edge_data(prev_id, neighbor)
                    if edge_data and edge_data.get('strength', 1.0) >= min_strength:
                        strength = prev_strength * edge_data.get('strength', 1.0)
                        results[depth].append((neighbor, strength))
                        visited.add(neighbor)

            # Nothing new at this depth means nothing further is reachable.
            if not results[depth]:
                break

        return results

    def get_lineage(self, memory_id: str) -> Dict:
        """
        Get the full lineage of a memory.

        Returns:
            Dict with:
            - 'memory_id': the queried ID
            - 'ancestors': IDs found by repeatedly following the first
              DERIVED_FROM relationship, nearest ancestor first
            - 'descendants': IDs of every memory derived from this one,
              directly or transitively, in breadth-first order
            - 'superseded_by': ID of a memory that supersedes this one, or
              None
        """
        ancestors: List[str] = []
        descendants: List[str] = []
        superseded_by = None

        # Walk up the derivation chain. The visited set stops the walk if
        # the chain loops back on itself.
        visited = {memory_id}
        current = memory_id
        while current:
            rels = self.get_relationships(current, RelationshipType.DERIVED_FROM, "outgoing")
            if not rels:
                break
            parent = rels[0].target_id
            if parent in visited:
                break
            visited.add(parent)
            ancestors.append(parent)
            current = parent

        # Walk down the derivation chain. Index parent -> children once
        # (a DERIVED_FROM edge points child -> parent), then go breadth-first.
        children: Dict[str, List[str]] = defaultdict(list)
        for rel in self._all_relationships():
            if rel.type == RelationshipType.DERIVED_FROM and rel.source_id != rel.target_id:
                children[rel.target_id].append(rel.source_id)

        seen = {memory_id}
        queue = deque([memory_id])
        while queue:
            node = queue.popleft()
            for child in children.get(node, ()):
                if child not in seen:
                    seen.add(child)
                    descendants.append(child)
                    queue.append(child)

        # Check if superseded. Incoming results are reversed views, so the
        # superseding memory is the view's target_id, not its source_id.
        sup_rel = self.get_relationships(memory_id, RelationshipType.SUPERSEDES, "incoming")
        if sup_rel:
            superseded_by = sup_rel[0].target_id

        return {
            'memory_id': memory_id,
            'ancestors': ancestors,
            'descendants': descendants,
            'superseded_by': superseded_by
        }

    def record_derivation(self, new_memory_id: str, parent_memory_id: str, notes: str = ""):
        """
        Record that a memory was derived from another.

        Adds a DERIVED_FROM relationship new -> parent and stores a lineage
        entry for the new memory. Its generation is the parent's generation
        plus one, or 1 when the parent has no lineage entry.
        """
        self.add_relationship(
            new_memory_id,
            parent_memory_id,
            RelationshipType.DERIVED_FROM,
            description=notes or f"Derived from {parent_memory_id}"
        )

        # Update lineage tracking
        parent_entry = self._lineage.get(parent_memory_id)
        if parent_entry is not None:
            generation = parent_entry.get('generation', 0) + 1
        else:
            generation = 1

        self._lineage[new_memory_id] = {
            'parent': parent_memory_id,
            'generation': generation,
            'created_at': _utcnow().isoformat(),
            'notes': notes
        }

        # add_relationship already saved, but before the lineage entry existed.
        self._save()

    def record_supersession(self, new_memory_id: str, old_memory_id: str, reason: str = ""):
        """Record that a memory supersedes another (edge new -> old)."""
        self.add_relationship(
            new_memory_id,
            old_memory_id,
            RelationshipType.SUPERSEDES,
            description=reason or f"Supersedes {old_memory_id}"
        )

    def find_conflicts(self, memory_id: str) -> List[MemoryRelationship]:
        """
        Find memories that conflict with the given memory.

        Looks in both directions, so a conflict is found whichever memory it
        was recorded from. In each result ``target_id`` is the conflicting
        memory.
        """
        return self.get_relationships(memory_id, RelationshipType.CONFLICTS_WITH, "both")

    def record_conflict(self, memory_a_id: str, memory_b_id: str, description: str = ""):
        """
        Record a conflict between two memories.

        A single CONFLICTS_WITH edge a -> b is stored. find_conflicts
        searches both directions, so the conflict is visible from either
        memory without a second, mirrored edge.
        """
        self.add_relationship(
            memory_a_id,
            memory_b_id,
            RelationshipType.CONFLICTS_WITH,
            description=description
        )

    def find_clusters(self, min_cluster_size: int = 3) -> List[Set[str]]:
        """
        Find clusters of closely related memories.

        Uses greedy modularity community detection on the undirected version
        of the relationship graph. Failures are reported with a printed
        warning and produce an empty result.

        Args:
            min_cluster_size: Smallest community size to report

        Returns:
            List of memory ID sets (clusters)
        """
        graph = self._ensure_graph()

        if len(graph) < min_cluster_size:
            return []

        try:
            # Use greedy modularity communities
            communities = nx.community.greedy_modularity_communities(graph.to_undirected())
            return [set(c) for c in communities if len(c) >= min_cluster_size]
        except Exception as e:
            print(f"Warning: Cluster detection failed: {e}")
            return []

    def get_central_memories(self, n: int = 5) -> List[Tuple[str, float]]:
        """
        Get the most central memories by PageRank.

        Args:
            n: Maximum number of memories to return

        Returns:
            List of (memory_id, centrality_score) tuples, highest score
            first; empty if the graph is empty or the calculation fails
        """
        graph = self._ensure_graph()

        if len(graph) == 0:
            return []

        try:
            centrality = nx.pagerank(graph)
            sorted_nodes = sorted(centrality.items(), key=lambda x: x[1], reverse=True)
            return sorted_nodes[:n]
        except Exception as e:
            print(f"Warning: Centrality calculation failed: {e}")
            return []

    def get_stats(self) -> Dict:
        """
        Get statistics about the knowledge graph.

        Returns:
            Dict with 'total_nodes', 'total_edges', 'by_type' (relationship
            count per type value), 'avg_clustering' and 'is_connected'
            (weak connectivity). The last two keep their defaults (0.0 /
            False) for an empty graph or when their calculation fails.
        """
        graph = self._ensure_graph()

        # Count by relationship type. This counts stored relationships, so
        # it can exceed total_edges when several share one node pair.
        by_type: Dict[str, int] = {}
        for rel in self._all_relationships():
            t = rel.type.value
            by_type[t] = by_type.get(t, 0) + 1

        stats = {
            'total_nodes': graph.number_of_nodes(),
            'total_edges': graph.number_of_edges(),
            'by_type': by_type,
            'avg_clustering': 0.0,
            'is_connected': False,
        }

        if len(graph) > 0:
            # Still best-effort, but no longer silent and no longer
            # swallowing KeyboardInterrupt / SystemExit.
            try:
                stats['avg_clustering'] = nx.average_clustering(graph.to_undirected())
            except Exception as e:
                print(f"Warning: Clustering calculation failed: {e}")

            try:
                stats['is_connected'] = nx.is_weakly_connected(graph)
            except Exception as e:
                print(f"Warning: Connectivity check failed: {e}")

        return stats

    def export_graph(self, format: str = "gexf") -> str:
        """
        Export the knowledge graph to a timestamped file in ``base_dir``.

        Args:
            format: Export format (gexf, graphml, adjacency)

        Returns:
            Path to exported file

        Raises:
            ValueError: If ``format`` is not one of the supported names
        """
        graph = self._ensure_graph()
        timestamp = _utcnow().strftime("%Y%m%d_%H%M%S")

        if format == "gexf":
            path = self.base_dir / f"knowledge_graph_{timestamp}.gexf"
            writer = nx.write_gexf
        elif format == "graphml":
            path = self.base_dir / f"knowledge_graph_{timestamp}.graphml"
            writer = nx.write_graphml
        elif format == "adjacency":
            path = self.base_dir / f"knowledge_graph_{timestamp}.adj"
            writer = nx.write_adjlist
        else:
            raise ValueError(f"Unknown format: {format}")

        # The target directory may not exist yet if nothing was saved so far.
        self.base_dir.mkdir(parents=True, exist_ok=True)
        writer(graph, str(path))

        return str(path)

    def visualize(self, output_path: Optional[Path] = None) -> Path:
        """
        Create a visualization of the knowledge graph.

        Args:
            output_path: Where to write the image; defaults to
                ``knowledge_graph.png`` inside ``base_dir``

        Returns:
            Path to the generated visualization

        Raises:
            ImportError: If matplotlib is not installed
            ValueError: If the graph has no nodes
        """
        try:
            import matplotlib.pyplot as plt  # type: ignore
        except ImportError as err:
            raise ImportError("matplotlib is required for visualization") from err

        graph = self._ensure_graph()

        if len(graph) == 0:
            raise ValueError("Cannot visualize empty graph")

        if output_path is None:
            output_path = self.base_dir / "knowledge_graph.png"
        else:
            # Accept str as well, and always return a Path as annotated.
            output_path = Path(output_path)

        # Draw edges with colors by type
        edge_colors = {
            'related': 'gray',
            'depends_on': 'red',
            'supersedes': 'orange',
            'derived_from': 'green',
            'conflicts_with': 'purple'
        }

        fig = plt.figure(figsize=(12, 10))
        try:
            # Layout
            pos = nx.spring_layout(graph, k=2, iterations=50)

            # Draw nodes
            # Color by type (would need to look up memory type)
            node_colors = ['skyblue'] * graph.number_of_nodes()

            nx.draw_networkx_nodes(graph, pos, node_color=node_colors, node_size=500, alpha=0.9)
            nx.draw_networkx_labels(graph, pos, font_size=8)

            edges_by_color: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
            for u, v, d in graph.edges(data=True):
                # Edge types missing from the table are drawn in black
                # instead of being left out of the picture.
                edges_by_color[edge_colors.get(d.get('type'), 'black')].append((u, v))

            for color, edges in edges_by_color.items():
                nx.draw_networkx_edges(graph, pos, edgelist=edges, edge_color=color,
                                       arrows=True, arrowsize=10, alpha=0.6)

            plt.title("AP Knowledge Graph")
            plt.axis('off')

            plt.savefig(output_path, bbox_inches='tight', dpi=150)
        finally:
            # Release the figure even if layout, drawing or saving failed.
            plt.close(fig)

        return output_path