# Source code for aiecs.application.knowledge_graph.reasoning.inference_engine

# /*---------------------------------------------------------------------------------------------
#  *  Copyright (c) IRETBL Corporation. All rights reserved.
#  *  Licensed under the Apache-2.0. See License.txt in the project root for license information.
#  *--------------------------------------------------------------------------------------------*/
"""
Inference Engine

Rule-based logical inference over knowledge graphs.
"""

import uuid
import time
from typing import List, Optional, Dict, Any, Set, Tuple
from collections import defaultdict
from aiecs.infrastructure.graph_storage.base import GraphStore
from aiecs.domain.knowledge_graph.models.relation import Relation
from aiecs.domain.knowledge_graph.models.inference_rule import (
    InferenceRule,
    InferenceStep,
    InferenceResult,
    RuleType,
)
from aiecs.domain.knowledge_graph.schema.relation_type import RelationType


class InferenceCache:
    """
    Cache for inference results

    Stores previously computed inference results to avoid recomputation.
    Entries are keyed by (relation type, source ID, target ID), optionally
    expire after ``ttl_seconds`` and are evicted least-recently-used first
    once ``max_size`` entries are stored.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: Optional[float] = None):
        """
        Initialize inference cache

        Args:
            max_size: Maximum number of cached entries (0 or less disables
                storing entirely)
            ttl_seconds: Time-to-live in seconds (None = no expiration)
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        # key -> (result, time the entry was stored); the timestamp drives TTL
        self._cache: Dict[str, Tuple["InferenceResult", float]] = {}
        # key -> time of the most recent get/put; drives LRU eviction
        self._access_times: Dict[str, float] = {}

    def _make_key(
        self,
        relation_type: str,
        source_id: Optional[str] = None,
        target_id: Optional[str] = None,
    ) -> str:
        """Create cache key; a missing (falsy) endpoint is encoded as ``*``"""
        if source_id and target_id:
            return f"{relation_type}:{source_id}:{target_id}"
        elif source_id:
            return f"{relation_type}:{source_id}:*"
        elif target_id:
            return f"{relation_type}:*:{target_id}"
        else:
            return f"{relation_type}:*:*"

    def _discard(self, key: str) -> None:
        """Drop one entry from both bookkeeping dicts (no error if absent)"""
        self._cache.pop(key, None)
        self._access_times.pop(key, None)

    def get(
        self,
        relation_type: str,
        source_id: Optional[str] = None,
        target_id: Optional[str] = None,
    ) -> Optional["InferenceResult"]:
        """
        Get cached inference result

        Args:
            relation_type: Relation type
            source_id: Source entity ID
            target_id: Target entity ID

        Returns:
            Cached result or None (missing or expired)
        """
        key = self._make_key(relation_type, source_id, target_id)
        entry = self._cache.get(key)
        if entry is None:
            return None

        result, cached_time = entry

        # Check TTL. Compare against None explicitly so that ttl_seconds=0
        # means "expire immediately" rather than "never expire".
        if self.ttl_seconds is not None and (time.time() - cached_time) > self.ttl_seconds:
            self._discard(key)
            return None

        # Update access time
        self._access_times[key] = time.time()
        return result

    def put(
        self,
        relation_type: str,
        result: "InferenceResult",
        source_id: Optional[str] = None,
        target_id: Optional[str] = None,
    ) -> None:
        """
        Cache inference result

        Args:
            relation_type: Relation type
            result: Inference result to cache
            source_id: Source entity ID
            target_id: Target entity ID
        """
        # A non-positive capacity means there is nothing to evict and no room
        # to store: treat the cache as disabled instead of failing in min().
        if self.max_size <= 0:
            return

        key = self._make_key(relation_type, source_id, target_id)

        # Evict if cache is full (LRU); overwriting an existing key needs no room
        if key not in self._cache:
            while self._cache and len(self._cache) >= self.max_size:
                lru_key = min(self._cache, key=lambda k: self._access_times.get(k, 0.0))
                self._discard(lru_key)

        # One clock reading so store time and access time agree
        now = time.time()
        self._cache[key] = (result, now)
        self._access_times[key] = now

    def clear(self) -> None:
        """Clear all cached results"""
        self._cache.clear()
        self._access_times.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "ttl_seconds": self.ttl_seconds,
        }
class InferenceEngine:
    """
    Rule-Based Inference Engine

    Applies logical inference rules to infer new relations from existing ones.

    Features:
    - Transitive inference (A->B, B->C => A->C)
    - Symmetric inference (A->B => B->A)
    - Custom inference rules
    - Result caching
    - Explainability (trace inference steps)

    Example:
        ```python
        engine = InferenceEngine(graph_store)

        # Add rules
        engine.add_rule(InferenceRule(
            rule_id="transitive_works_for",
            rule_type=RuleType.TRANSITIVE,
            relation_type="WORKS_FOR"
        ))

        # Infer relations
        result = await engine.infer_relations(
            relation_type="WORKS_FOR",
            max_steps=3
        )

        print(f"Inferred {len(result.inferred_relations)} relations")
        print(result.get_explanation_string())
        ```
    """

    def __init__(self, graph_store: GraphStore, cache: Optional[InferenceCache] = None):
        """
        Initialize inference engine

        Args:
            graph_store: Graph storage backend
            cache: Optional inference cache (creates one if not provided)
        """
        self.graph_store = graph_store
        # Explicit None check: a caller-supplied cache must be used even if
        # its type happens to evaluate as falsy.
        self.cache = cache if cache is not None else InferenceCache()
        self.rules: Dict[str, InferenceRule] = {}
        self.relation_type_schemas: Dict[str, RelationType] = {}

    def add_rule(self, rule: InferenceRule) -> None:
        """
        Add an inference rule

        Cached results are dropped because they were computed with the
        previous rule set.

        Args:
            rule: Inference rule to add (replaces any rule with the same ID)
        """
        self.rules[rule.rule_id] = rule
        self.cache.clear()

    def remove_rule(self, rule_id: str) -> None:
        """
        Remove an inference rule

        Unknown IDs are ignored. When a rule is actually removed, cached
        results are dropped because they may depend on it.

        Args:
            rule_id: ID of rule to remove
        """
        if self.rules.pop(rule_id, None) is not None:
            self.cache.clear()

    def get_rules(self, relation_type: Optional[str] = None) -> List[InferenceRule]:
        """
        Get inference rules

        Args:
            relation_type: Filter by relation type (None = all)

        Returns:
            List of inference rules
        """
        rules = list(self.rules.values())
        if relation_type:
            rules = [r for r in rules if r.relation_type == relation_type]
        return rules

    async def infer_relations(
        self,
        relation_type: str,
        max_steps: int = 10,
        source_id: Optional[str] = None,
        target_id: Optional[str] = None,
        use_cache: bool = True,
    ) -> InferenceResult:
        """
        Infer relations using enabled rules

        Rules are applied repeatedly to the known relations plus everything
        inferred so far, until a pass yields nothing new or ``max_steps``
        passes have run.

        Args:
            relation_type: Relation type to infer
            max_steps: Maximum number of inference steps
            source_id: Optional source entity ID filter
            target_id: Optional target entity ID filter
            use_cache: Whether to use cache

        Returns:
            Inference result with inferred relations and steps
        """
        # Check cache
        if use_cache:
            cached = self.cache.get(relation_type, source_id, target_id)
            if cached is not None:
                return cached

        # Get applicable rules
        applicable_rules = [rule for rule in self.get_rules(relation_type) if rule.enabled]

        if not applicable_rules:
            return InferenceResult(
                inferred_relations=[],
                inference_steps=[],
                total_steps=0,
                confidence=0.0,
                explanation=f"No inference rules enabled for relation type: {relation_type}",
            )

        # Seed with the relations already known for this type. All seeding
        # goes through _get_relations so there is a single place to plug in
        # a real lookup.
        current_relations: List[Relation] = list(await self._get_relations(relation_type))

        inferred_relations: List[Relation] = []
        inference_steps: List[InferenceStep] = []

        # Track inferred relation IDs to avoid duplicates
        visited: Set[str] = set()

        # Apply rules iteratively
        step_count = 0
        while step_count < max_steps:
            candidates: List[Tuple[Relation, InferenceStep]] = []

            for rule in applicable_rules:
                if rule.rule_type == RuleType.TRANSITIVE:
                    candidates.extend(await self._apply_transitive_rule(rule, current_relations, visited))
                elif rule.rule_type == RuleType.SYMMETRIC:
                    candidates.extend(await self._apply_symmetric_rule(rule, current_relations, visited))

            if not candidates:
                break  # No new relations inferred

            # Add new relations; two rules may propose the same ID in one pass
            for rel, step in candidates:
                if rel.id not in visited:
                    inferred_relations.append(rel)
                    inference_steps.append(step)
                    visited.add(rel.id)
                    current_relations.append(rel)

            step_count += 1

        # Apply the endpoint filters only now: intermediate relations outside
        # the filter are still needed while chaining. Falsy IDs mean "no
        # filter", matching how the cache key treats them.
        if source_id or target_id:
            kept = [
                (rel, step)
                for rel, step in zip(inferred_relations, inference_steps)
                if (not source_id or rel.source_id == source_id)
                and (not target_id or rel.target_id == target_id)
            ]
            inferred_relations = [rel for rel, _ in kept]
            inference_steps = [step for _, step in kept]

        # Calculate overall confidence
        if inference_steps:
            confidence = sum(step.confidence for step in inference_steps) / len(inference_steps)
        else:
            confidence = 0.0

        # Create result
        result = InferenceResult(
            inferred_relations=inferred_relations,
            inference_steps=inference_steps,
            total_steps=step_count,
            confidence=confidence,
            explanation=f"Inferred {len(inferred_relations)} relations using {len(applicable_rules)} rules in {step_count} steps",
        )

        # Cache result
        if use_cache:
            self.cache.put(relation_type, result, source_id, target_id)

        return result

    async def _get_relations(self, relation_type: str) -> List[Relation]:
        """
        Get all relations of a given type

        Note:
            This is a placeholder: it currently always returns an empty list,
            so inference has no seed relations to work from. In production,
            GraphStore should have a get_relations_by_type method that this
            hook delegates to.
        """
        relations: List[Relation] = []
        return relations

    async def _apply_transitive_rule(self, rule: InferenceRule, relations: List[Relation], visited: Set[str]) -> List[Tuple[Relation, InferenceStep]]:
        """
        Apply transitive rule: A->B, B->C => A->C

        Args:
            rule: Transitive rule to apply
            relations: Existing relations
            visited: Set of already inferred relation IDs

        Returns:
            List of (inferred_relation, inference_step) tuples
        """
        inferred: List[Tuple[Relation, InferenceStep]] = []

        # Only relations of the rule's type take part
        typed = [rel for rel in relations if rel.relation_type == rule.relation_type]

        # Build index: source -> outgoing relations, plus the set of
        # (source, target) pairs that are already covered
        source_to_targets: Dict[str, List[Relation]] = defaultdict(list)
        known_pairs: Set[Tuple[str, str]] = set()
        for rel in typed:
            source_to_targets[rel.source_id].append(rel)
            known_pairs.add((rel.source_id, rel.target_id))

        # Find transitive chains
        for rel1 in typed:
            # rel1: A -> B
            # Find relations where B is source: B -> C
            for rel2 in source_to_targets.get(rel1.target_id, ()):
                pair = (rel1.source_id, rel2.target_id)

                # Check if A -> C already exists among the given relations or
                # was already handled in this pass (same check the symmetric
                # rule performs)
                if pair in known_pairs:
                    continue

                inferred_id = f"inf_{rel1.source_id}_{rel2.target_id}_{rule.relation_type}"
                if inferred_id in visited:
                    continue

                # From here on the pair is settled for this pass, whichever
                # branch is taken; this also avoids repeated store lookups
                known_pairs.add(pair)

                # Check if relation already exists in the store
                existing = await self.graph_store.get_relation(inferred_id)
                if existing:
                    continue

                # Create inferred relation
                # Confidence decays with each step
                confidence = min(rel1.weight, rel2.weight) * (1.0 - rule.confidence_decay)

                inferred_rel = Relation(
                    id=inferred_id,
                    relation_type=rule.relation_type,
                    source_id=rel1.source_id,
                    target_id=rel2.target_id,
                    weight=confidence,
                    properties={
                        "inferred": True,
                        "source_relations": [rel1.id, rel2.id],
                        "rule_id": rule.rule_id,
                    },
                )

                # Create inference step
                step = InferenceStep(
                    step_id=f"step_{uuid.uuid4().hex[:8]}",
                    inferred_relation=inferred_rel,
                    source_relations=[rel1, rel2],
                    rule=rule,
                    confidence=confidence,
                    explanation=f"Transitive: {rel1.source_id} -> {rel1.target_id} -> {rel2.target_id} => {rel1.source_id} -> {rel2.target_id}",
                )

                inferred.append((inferred_rel, step))

        return inferred

    async def _apply_symmetric_rule(self, rule: InferenceRule, relations: List[Relation], visited: Set[str]) -> List[Tuple[Relation, InferenceStep]]:
        """
        Apply symmetric rule: A->B => B->A

        Args:
            rule: Symmetric rule to apply
            relations: Existing relations
            visited: Set of already inferred relation IDs

        Returns:
            List of (inferred_relation, inference_step) tuples
        """
        inferred: List[Tuple[Relation, InferenceStep]] = []

        # Only relations of the rule's type take part
        typed = [rel for rel in relations if rel.relation_type == rule.relation_type]

        # Build set of existing relations (source, target) pairs
        existing_pairs: Set[Tuple[str, str]] = {(rel.source_id, rel.target_id) for rel in typed}

        # Find relations that need symmetric inference
        for rel in typed:
            # Check if reverse already exists (or was handled in this pass)
            reverse_pair = (rel.target_id, rel.source_id)
            if reverse_pair in existing_pairs:
                continue

            # Check if already inferred
            inferred_id = f"inf_{rel.target_id}_{rel.source_id}_{rule.relation_type}"
            if inferred_id in visited:
                continue

            # The reverse pair is settled for this pass, whichever branch
            # follows; parallel A->B relations must not yield B->A twice
            existing_pairs.add(reverse_pair)

            # Check if relation already exists in the store
            existing = await self.graph_store.get_relation(inferred_id)
            if existing:
                continue

            # Create inferred relation
            # Confidence slightly lower than original
            confidence = rel.weight * (1.0 - rule.confidence_decay)

            inferred_rel = Relation(
                id=inferred_id,
                relation_type=rule.relation_type,
                source_id=rel.target_id,
                target_id=rel.source_id,
                weight=confidence,
                properties={
                    "inferred": True,
                    "source_relations": [rel.id],
                    "rule_id": rule.rule_id,
                },
            )

            # Create inference step
            step = InferenceStep(
                step_id=f"step_{uuid.uuid4().hex[:8]}",
                inferred_relation=inferred_rel,
                source_relations=[rel],
                rule=rule,
                confidence=confidence,
                explanation=f"Symmetric: {rel.source_id} -> {rel.target_id} => {rel.target_id} -> {rel.source_id}",
            )

            inferred.append((inferred_rel, step))

        return inferred

    def get_inference_trace(self, result: InferenceResult) -> List[str]:
        """
        Get human-readable trace of inference steps

        Args:
            result: Inference result

        Returns:
            List of trace strings: one header line, then three lines per step
        """
        trace = [f"Inference trace for {result.total_steps} steps:"]

        for i, step in enumerate(result.inference_steps, 1):
            trace.append(f" Step {i}: {step.explanation}")
            trace.append(f" Confidence: {step.confidence:.2f}")
            trace.append(f" Rule: {step.rule.rule_id} ({step.rule.rule_type})")

        return trace