Source code for aiecs.application.knowledge_graph.reasoning.query_planner

# /*---------------------------------------------------------------------------------------------
#  *  Copyright (c) IRETBL Corporation. All rights reserved.
#  *  Licensed under the Apache-2.0. See License.txt in the project root for license information.
#  *--------------------------------------------------------------------------------------------*/
"""
Query Planner

Translates natural language queries to structured graph query plans.
Decomposes complex queries into executable steps and optimizes execution order.
"""

import uuid
import re
from typing import Optional, List, Dict, Any, Set, Union
from aiecs.infrastructure.graph_storage.base import GraphStore
from aiecs.domain.knowledge_graph.models.query import GraphQuery, QueryType
from aiecs.domain.knowledge_graph.models.query_plan import (
    QueryPlan,
    QueryStep,
    QueryOperation,
    OptimizationStrategy,
)
from aiecs.infrastructure.graph_storage.query_optimizer import (
    QueryOptimizer,
    QueryStatisticsCollector,
)

# Import LogicQueryParser for DSL support
try:
    from aiecs.application.knowledge_graph.reasoning.logic_parser import (
        LogicQueryParser,
        ParserError,
    )

    LOGIC_PARSER_AVAILABLE = True
except ImportError:
    LOGIC_PARSER_AVAILABLE = False
    LogicQueryParser = None  # type: ignore[misc, assignment]
    ParserError = None  # type: ignore[misc, assignment]


[docs] class QueryPlanner: """ Query Planning Engine Translates natural language queries into structured, optimized execution plans. Features: - Natural language to graph query translation - Query decomposition (complex queries → multiple steps) - Query optimization (reorder operations for efficiency) - Cost estimation Example:: planner = QueryPlanner(graph_store) # Plan a complex query plan = planner.plan_query( "Who works at companies that Alice knows people at?" ) # Optimize the plan optimized_plan = planner.optimize_plan( plan, strategy=OptimizationStrategy.MINIMIZE_COST ) """
[docs] def __init__( self, graph_store: GraphStore, enable_advanced_optimization: bool = True, schema: Optional[Any] = None, ): """ Initialize query planner Args: graph_store: Graph storage backend for queries enable_advanced_optimization: Enable advanced query optimization (default: True) schema: Optional schema manager for logic query validation """ self.graph_store = graph_store self.schema = schema # Pattern templates for query understanding self.query_patterns = self._initialize_query_patterns() # Advanced query optimizer self._enable_advanced_optimization = enable_advanced_optimization # Type annotations for optional attributes self._optimizer: Optional[QueryOptimizer] self._statistics_collector: Optional[QueryStatisticsCollector] self._logic_parser: Optional[LogicQueryParser] if enable_advanced_optimization: # Collect statistics from graph store collector = QueryStatisticsCollector() statistics = collector.collect_from_graph_store(graph_store) # Initialize optimizer self._optimizer = QueryOptimizer(statistics=statistics) self._statistics_collector = collector else: self._optimizer = None self._statistics_collector = None # Logic query parser (if available) if LOGIC_PARSER_AVAILABLE and schema is not None: self._logic_parser = LogicQueryParser(schema=schema) else: self._logic_parser = None
def _initialize_query_patterns(self) -> List[Dict[str, Any]]: """Initialize query pattern matchers""" return [ { "pattern": r"find (.*?) with (.*?) = (['\"]?.+?['\"]?)", "type": "entity_lookup_by_property", "operations": ["filter"], }, { "pattern": r"who (works at|is employed by) (.*?)", "type": "relation_traversal", "operations": ["entity_lookup", "traversal"], }, { "pattern": r"what (companies|organizations) does (.*?) know people at", "type": "multi_hop_query", "operations": ["entity_lookup", "traversal", "traversal"], }, { "pattern": r"(similar|related) to (.*?)", "type": "vector_search", "operations": ["vector_search"], }, { "pattern": r"path from (.*?) to (.*?)", "type": "path_finding", "operations": ["path_finding"], }, { "pattern": r"neighbors of (.*?)", "type": "neighbor_query", "operations": ["entity_lookup", "traversal"], }, ]
[docs] def plan_query( self, natural_language_query: str, context: Optional[Dict[str, Any]] = None, ) -> QueryPlan: """ Create an execution plan from natural language query Args: natural_language_query: Natural language query string context: Optional context (e.g., embeddings, entity IDs) Returns: Query execution plan Example:: plan = planner.plan_query( "Find papers similar to 'Deep Learning' and their authors" ) """ context = context or {} plan_id = f"plan_{uuid.uuid4().hex[:8]}" # Analyze query structure query_info = self._analyze_query(natural_language_query) # Decompose into steps steps = self._decompose_query(natural_language_query, query_info, context) # Create plan plan = QueryPlan( plan_id=plan_id, original_query=natural_language_query, steps=steps, explanation=self._generate_explanation(steps), metadata={"query_info": query_info}, ) # Calculate total cost plan.total_estimated_cost = plan.calculate_total_cost() return plan
def _analyze_query(self, query: str) -> Dict[str, Any]: """ Analyze query to determine type and complexity Args: query: Natural language query Returns: Query analysis information """ query_lower = query.lower() # Match against known patterns matched_pattern = None for pattern_info in self.query_patterns: if re.search(pattern_info["pattern"], query_lower): matched_pattern = pattern_info break # Determine complexity is_multi_hop = any( keyword in query_lower for keyword in [ "who works at", "people at", "friends of", "colleagues", "through", "connected to", "related through", ] ) has_vector_search = any(keyword in query_lower for keyword in ["similar", "related", "like", "semantically"]) has_path_finding = any(keyword in query_lower for keyword in ["path", "route", "connection", "how to get"]) return { "matched_pattern": matched_pattern, "is_multi_hop": is_multi_hop, "has_vector_search": has_vector_search, "has_path_finding": has_path_finding, "complexity": self._estimate_complexity(query_lower), } def _estimate_complexity(self, query: str) -> str: """Estimate query complexity""" hop_indicators = query.count("who") + query.count("what") + query.count("which") if hop_indicators > 2 or "through" in query: return "high" elif hop_indicators > 0 or any(k in query for k in ["find", "get", "show"]): return "medium" else: return "low" def _decompose_query(self, query: str, query_info: Dict[str, Any], context: Dict[str, Any]) -> List[QueryStep]: """ Decompose query into executable steps Args: query: Natural language query query_info: Query analysis information context: Query context Returns: List of query steps """ steps = [] # Use matched pattern if available if query_info["matched_pattern"]: steps = self._create_steps_from_pattern(query, query_info["matched_pattern"], context) else: # Fall back to generic decomposition steps = self._create_generic_steps(query, query_info, context) return steps def _create_steps_from_pattern(self, query: str, pattern_info: Dict[str, Any], context: Dict[str, Any]) -> List[QueryStep]: """Create steps based on matched pattern""" steps = [] query_type = pattern_info["type"] if query_type == "entity_lookup_by_property": # Single step: filter entities by property steps.append( QueryStep( step_id="step_1", operation=QueryOperation.FILTER, query=GraphQuery( query_type=QueryType.CUSTOM, properties=context.get("properties", {}), max_results=context.get("max_results", 10), ), description="Filter entities by properties", estimated_cost=0.3, ) ) elif query_type == "relation_traversal": # Two steps: lookup entity, traverse relations steps.append( QueryStep( step_id="step_1", operation=QueryOperation.ENTITY_LOOKUP, query=GraphQuery( query_type=QueryType.ENTITY_LOOKUP, entity_id=context.get("entity_id"), max_results=1, ), description="Look up starting entity", estimated_cost=0.2, ) ) steps.append( QueryStep( step_id="step_2", operation=QueryOperation.TRAVERSAL, query=GraphQuery( query_type=QueryType.TRAVERSAL, relation_type=context.get("relation_type"), max_depth=context.get("max_depth", 1), max_results=context.get("max_results", 10), ), depends_on=["step_1"], description="Traverse relations from starting entity", estimated_cost=0.5, ) ) elif query_type == "multi_hop_query": # Multiple hops steps = self._create_multi_hop_steps(query, context) elif query_type == "vector_search": # Single step: vector similarity search steps.append( QueryStep( step_id="step_1", operation=QueryOperation.VECTOR_SEARCH, query=GraphQuery( query_type=QueryType.VECTOR_SEARCH, embedding=context.get("query_embedding"), entity_type=context.get("entity_type"), max_results=context.get("max_results", 10), score_threshold=context.get("score_threshold", 0.7), ), description="Find semantically similar entities", estimated_cost=0.4, ) ) elif query_type == "path_finding": # Single step: find path between entities steps.append( QueryStep( step_id="step_1", operation=QueryOperation.TRAVERSAL, query=GraphQuery( query_type=QueryType.PATH_FINDING, source_entity_id=context.get("source_id"), target_entity_id=context.get("target_id"), max_depth=context.get("max_depth", 5), max_results=context.get("max_results", 10), ), description="Find paths between entities", estimated_cost=0.7, ) ) elif query_type == "neighbor_query": # Two steps: lookup + get neighbors steps.append( QueryStep( step_id="step_1", operation=QueryOperation.ENTITY_LOOKUP, query=GraphQuery( query_type=QueryType.ENTITY_LOOKUP, entity_id=context.get("entity_id"), max_results=1, ), description="Look up central entity", estimated_cost=0.2, ) ) steps.append( QueryStep( step_id="step_2", operation=QueryOperation.TRAVERSAL, query=GraphQuery( query_type=QueryType.TRAVERSAL, max_depth=1, max_results=context.get("max_results", 20), ), depends_on=["step_1"], description="Get neighboring entities", estimated_cost=0.4, ) ) return steps def _create_multi_hop_steps(self, query: str, context: Dict[str, Any]) -> List[QueryStep]: """Create steps for multi-hop query""" steps = [] num_hops = context.get("num_hops", 2) # Step 1: Find starting entity steps.append( QueryStep( step_id="step_1", operation=QueryOperation.ENTITY_LOOKUP, query=GraphQuery( query_type=QueryType.ENTITY_LOOKUP, entity_id=context.get("start_entity_id"), max_results=1, ), description="Find starting entity", estimated_cost=0.2, ) ) # Create hop steps for i in range(num_hops): hop_num = i + 1 step_id = f"step_{hop_num + 1}" depends_on = [f"step_{hop_num}"] steps.append( QueryStep( step_id=step_id, operation=QueryOperation.TRAVERSAL, query=GraphQuery( query_type=QueryType.TRAVERSAL, relation_type=context.get(f"hop{hop_num}_relation"), max_depth=1, max_results=context.get("max_results", 20), ), depends_on=depends_on, description=f"Hop {hop_num}: Traverse to next level", estimated_cost=0.4 + (0.1 * i), # Cost increases with depth ) ) return steps def _create_generic_steps(self, query: str, query_info: Dict[str, Any], context: Dict[str, Any]) -> List[QueryStep]: """Create generic steps when no pattern matches""" steps = [] # Priority 1: If start_entity_id is provided, use traversal if context.get("start_entity_id"): # Step 1: Lookup starting entity steps.append( QueryStep( step_id="step_1", operation=QueryOperation.ENTITY_LOOKUP, query=GraphQuery( query_type=QueryType.ENTITY_LOOKUP, entity_id=context.get("start_entity_id"), max_results=1, ), description="Look up starting entity", estimated_cost=0.2, ) ) # Step 2: Traverse from starting entity target_id = context.get("target_entity_id") if target_id: # Path finding if target is specified steps.append( QueryStep( step_id="step_2", operation=QueryOperation.TRAVERSAL, query=GraphQuery( query_type=QueryType.PATH_FINDING, source_entity_id=context.get("start_entity_id"), target_entity_id=target_id, max_depth=context.get("max_hops", 3), max_results=context.get("max_results", 10), ), depends_on=["step_1"], description="Find paths from start to target entity", estimated_cost=0.6, ) ) else: # General traversal if no target steps.append( QueryStep( step_id="step_2", operation=QueryOperation.TRAVERSAL, query=GraphQuery( query_type=QueryType.TRAVERSAL, entity_id=context.get("start_entity_id"), relation_type=(context.get("relation_types", [None])[0] if context.get("relation_types") else None), max_depth=context.get("max_hops", 3), max_results=context.get("max_results", 10), ), depends_on=["step_1"], description="Traverse from starting entity", estimated_cost=0.5, ) ) # Priority 2: If query_embedding is provided, use vector search elif context.get("query_embedding"): steps.append( QueryStep( step_id="step_1", operation=QueryOperation.VECTOR_SEARCH, query=GraphQuery( query_type=QueryType.VECTOR_SEARCH, embedding=context.get("query_embedding"), entity_type=context.get("entity_type"), max_results=context.get("max_results", 10), score_threshold=context.get("score_threshold", 0.5), ), description="Search for relevant entities using vector similarity", estimated_cost=0.5, ) ) # Priority 3: Default fallback - entity lookup by type if entity_type # is provided elif context.get("entity_type"): steps.append( QueryStep( step_id="step_1", operation=QueryOperation.FILTER, query=GraphQuery( query_type=QueryType.ENTITY_LOOKUP, entity_type=context.get("entity_type"), max_results=context.get("max_results", 10), ), description=f"Filter entities by type: {context.get('entity_type')}", estimated_cost=0.3, ) ) # Priority 4: Last resort - simple vector search (may not work without # embeddings) else: steps.append( QueryStep( step_id="step_1", operation=QueryOperation.VECTOR_SEARCH, query=GraphQuery( query_type=QueryType.VECTOR_SEARCH, embedding=None, # Will need to be generated max_results=context.get("max_results", 10), score_threshold=0.5, ), description="Search for relevant entities (fallback - may not work without embeddings)", estimated_cost=0.5, ) ) return steps def _generate_explanation(self, steps: List[QueryStep]) -> str: """Generate human-readable explanation of plan""" if not steps: return "No steps in plan" if len(steps) == 1: return f"Single-step query: {steps[0].description}" parts = [f"Multi-step query with {len(steps)} steps:"] for i, step in enumerate(steps, 1): parts.append(f"{i}. {step.description}") return "\n".join(parts)
[docs] def optimize_plan( self, plan: QueryPlan, strategy: OptimizationStrategy = OptimizationStrategy.BALANCED, ) -> QueryPlan: """ Optimize query execution plan Args: plan: Original query plan strategy: Optimization strategy Returns: Optimized query plan Example:: optimized = planner.optimize_plan( plan, strategy=OptimizationStrategy.MINIMIZE_COST ) """ if plan.optimized: return plan # Already optimized # Use advanced optimizer if enabled if self._enable_advanced_optimization and self._optimizer: result = self._optimizer.optimize(plan) return result.optimized_plan # Fall back to basic optimization optimized_steps = list(plan.steps) if strategy == OptimizationStrategy.MINIMIZE_COST: optimized_steps = self._optimize_for_cost(optimized_steps) elif strategy == OptimizationStrategy.MINIMIZE_LATENCY: optimized_steps = self._optimize_for_latency(optimized_steps) else: # BALANCED optimized_steps = self._optimize_balanced(optimized_steps) # Create optimized plan optimized_plan = QueryPlan( plan_id=plan.plan_id + "_opt", original_query=plan.original_query, steps=optimized_steps, optimized=True, explanation=plan.explanation + "\n(Optimized)", metadata=plan.metadata, ) optimized_plan.total_estimated_cost = optimized_plan.calculate_total_cost() return optimized_plan
def _optimize_for_cost(self, steps: List[QueryStep]) -> List[QueryStep]: """ Optimize to minimize total cost Strategy: Execute cheaper operations first when possible """ # Group steps by dependency level levels = self._get_dependency_levels(steps) optimized = [] for level_steps in levels: # Sort by cost (ascending) within each level sorted_level = sorted(level_steps, key=lambda s: s.estimated_cost) optimized.extend(sorted_level) return optimized def _optimize_for_latency(self, steps: List[QueryStep]) -> List[QueryStep]: """ Optimize to minimize latency Strategy: Maximize parallelization """ # Already maximized in get_execution_order() # Just return original order return steps def _optimize_balanced(self, steps: List[QueryStep]) -> List[QueryStep]: """ Balanced optimization Strategy: Balance cost and latency """ levels = self._get_dependency_levels(steps) optimized = [] for level_steps in levels: # Sort by cost but not too aggressively # Keep expensive operations that can run in parallel sorted_level = sorted( level_steps, key=lambda s: (s.estimated_cost > 0.7, s.estimated_cost), ) optimized.extend(sorted_level) return optimized def _get_dependency_levels(self, steps: List[QueryStep]) -> List[List[QueryStep]]: """ Group steps by dependency level Returns: List of lists, each containing steps at the same dependency level """ # step_map = {step.step_id: step for step in steps} # Reserved for # future use levels: List[List[QueryStep]] = [] processed: Set[str] = set() while len(processed) < len(steps): current_level = [] for step in steps: if step.step_id in processed: continue # Check if all dependencies are processed if all(dep in processed for dep in step.depends_on): current_level.append(step) if not current_level: break # Should not happen with valid dependencies levels.append(current_level) processed.update(step.step_id for step in current_level) return levels
[docs] def translate_to_graph_query( self, natural_language_query: str, context: Optional[Dict[str, Any]] = None, ) -> GraphQuery: """ Translate natural language to a single graph query For simple queries that don't need decomposition. Args: natural_language_query: Natural language query context: Query context (embeddings, entity IDs, etc.) Returns: Single graph query Example:: query = planner.translate_to_graph_query( "Find entities similar to X", context={"query_embedding": [0.1, 0.2, ...]} ) """ context = context or {} query_lower = natural_language_query.lower() # Determine query type if "similar" in query_lower or "related" in query_lower: return GraphQuery( query_type=QueryType.VECTOR_SEARCH, embedding=context.get("query_embedding"), entity_type=context.get("entity_type"), max_results=context.get("max_results", 10), score_threshold=context.get("score_threshold", 0.7), ) elif "path" in query_lower: return GraphQuery( query_type=QueryType.PATH_FINDING, source_entity_id=context.get("source_id"), target_entity_id=context.get("target_id"), max_depth=context.get("max_depth", 5), max_results=context.get("max_results", 10), ) elif "neighbor" in query_lower or "connected to" in query_lower: return GraphQuery( query_type=QueryType.TRAVERSAL, entity_id=context.get("entity_id"), relation_type=context.get("relation_type"), max_depth=1, max_results=context.get("max_results", 20), ) else: # Default to entity lookup return GraphQuery( query_type=QueryType.ENTITY_LOOKUP, entity_id=context.get("entity_id"), entity_type=context.get("entity_type"), properties=context.get("properties", {}), max_results=context.get("max_results", 10), )
# Advanced Optimization Methods
[docs] def update_statistics(self) -> None: """ Update query statistics from graph store Call this periodically to keep optimizer statistics up-to-date """ if self._enable_advanced_optimization and self._statistics_collector and self._optimizer: statistics = self._statistics_collector.collect_from_graph_store(self.graph_store) self._optimizer.update_statistics(statistics)
[docs] def record_execution_time(self, execution_time_ms: float) -> None: """ Record query execution time for statistics Args: execution_time_ms: Execution time in milliseconds """ if self._statistics_collector: self._statistics_collector.record_execution_time(execution_time_ms)
[docs] def get_optimizer_stats(self) -> Dict[str, Any]: """ Get optimizer statistics Returns: Dictionary with optimizer statistics """ if not self._enable_advanced_optimization or not self._optimizer: return {"enabled": False} return { "enabled": True, "optimizations_performed": self._optimizer.get_optimization_count(), "avg_execution_time_ms": (self._statistics_collector.get_average_execution_time() if self._statistics_collector else 0.0), "p95_execution_time_ms": (self._statistics_collector.get_execution_percentile(0.95) if self._statistics_collector else 0.0), "entity_count": self._optimizer.statistics.entity_count, "relation_count": self._optimizer.statistics.relation_count, "avg_degree": self._optimizer.statistics.avg_degree, }
# ======================================================================== # Logic Query Support # ========================================================================
[docs] def plan_logic_query(self, logic_query: str) -> Union[QueryPlan, List[Any]]: """ Create execution plan from logic query DSL This method parses a logic query (e.g., "Find(Person) WHERE age > 30") and converts it directly to a QueryPlan. Args: logic_query: Logic query string in DSL format Returns: QueryPlan if successful, List[ParserError] if errors occurred Example:: plan = planner.plan_logic_query("Find(Person) WHERE age > 30") if isinstance(plan, list): # Parsing errors for error in plan: print(f"Error at line {error.line}: {error.message}") else: # Success - execute the plan result = await graph_store.execute_plan(plan) """ if not LOGIC_PARSER_AVAILABLE: raise ImportError("Logic parser not available. Install lark-parser.") if self._logic_parser is None: raise ValueError("Logic parser not initialized. Provide schema to QueryPlanner.") # Parse logic query to QueryPlan return self._logic_parser.parse_to_query_plan(logic_query) # type: ignore[no-any-return]
[docs] def supports_logic_queries(self) -> bool: """ Check if logic query support is available Returns: True if logic queries are supported, False otherwise """ return LOGIC_PARSER_AVAILABLE and self._logic_parser is not None