Source code for aiecs.application.knowledge_graph.builder.graph_builder

# /*---------------------------------------------------------------------------------------------
#  *  Copyright (c) IRETBL Corporation. All rights reserved.
#  *  Licensed under the Apache-2.0. See License.txt in the project root for license information.
#  *--------------------------------------------------------------------------------------------*/
"""
Graph Builder - Main Pipeline Orchestrator

Orchestrates the full document-to-graph conversion pipeline.
"""

import asyncio
from typing import List, Optional, Dict, Any, Callable, cast, TYPE_CHECKING
from dataclasses import dataclass, field
from datetime import datetime
import logging

from aiecs.domain.knowledge_graph.schema.graph_schema import GraphSchema
from aiecs.infrastructure.graph_storage.base import GraphStore
from aiecs.application.knowledge_graph.extractors.base import (
    EntityExtractor,
    RelationExtractor,
)
from aiecs.application.knowledge_graph.fusion.entity_deduplicator import (
    EntityDeduplicator,
)
from aiecs.application.knowledge_graph.fusion.entity_linker import EntityLinker
from aiecs.application.knowledge_graph.fusion.relation_deduplicator import (
    RelationDeduplicator,
)
from aiecs.application.knowledge_graph.validators.relation_validator import (
    RelationValidator,
)

if TYPE_CHECKING:
    from aiecs.llm.protocols import LLMClientProtocol

logger = logging.getLogger(__name__)


@dataclass
class BuildResult:
    """
    Outcome summary of a single graph-building run.

    Every field has a default, so an empty ``BuildResult()`` represents a
    successful run that has not recorded anything yet. The mutable fields
    (``errors``, ``warnings``, ``metadata``) are created fresh per instance.

    Attributes:
        success: ``True`` unless the build was marked as failed
        entities_added: Count of entities written to the graph
        relations_added: Count of relations written to the graph
        entities_linked: Count of entities matched to already-stored entities
        entities_deduplicated: Count of entities removed as duplicates
        relations_deduplicated: Count of relations removed as duplicates
        errors: Error messages collected during the build
        warnings: Warning messages collected during the build
        metadata: Free-form extra information about the build
        start_time: Timestamp at which the build began (``None`` if unset)
        end_time: Timestamp at which the build finished (``None`` if unset)
        duration_seconds: Elapsed wall-clock time of the build, in seconds
    """

    # --- overall status ---
    success: bool = True

    # --- counters ---
    entities_added: int = 0
    relations_added: int = 0
    entities_linked: int = 0
    entities_deduplicated: int = 0
    relations_deduplicated: int = 0

    # --- diagnostics (per-instance containers via default_factory) ---
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # --- timing ---
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    duration_seconds: float = 0.0


class GraphBuilder:
    """
    Main pipeline for building knowledge graphs from text

    The pipeline:
    1. Extract entities from text
    2. Deduplicate entities
    3. Link entities to existing graph
    4. Extract relations between entities
    5. Validate relations
    6. Deduplicate relations
    7. Generate embeddings for new entities (only if an embedding client is set)
    8. Store entities and relations in graph

    Features:
    - Async/parallel processing
    - Progress callbacks
    - Error handling and recovery
    - Provenance tracking
    - Configurable components

    Example::

        # Initialize components
        entity_extractor = LLMEntityExtractor(schema)
        relation_extractor = LLMRelationExtractor(schema)

        # Create builder
        builder = GraphBuilder(
            graph_store=store,
            entity_extractor=entity_extractor,
            relation_extractor=relation_extractor,
            schema=schema
        )

        # Build graph from text
        result = await builder.build_from_text(
            text="Alice works at Tech Corp.",
            source="document_1.pdf"
        )

        print(f"Added {result.entities_added} entities, {result.relations_added} relations")
    """

    def __init__(
        self,
        graph_store: GraphStore,
        entity_extractor: EntityExtractor,
        relation_extractor: RelationExtractor,
        schema: Optional[GraphSchema] = None,
        enable_deduplication: bool = True,
        enable_linking: bool = True,
        enable_validation: bool = True,
        progress_callback: Optional[Callable[[str, float], None]] = None,
        embedding_client: Optional["LLMClientProtocol"] = None,
    ):
        """
        Initialize graph builder

        Args:
            graph_store: Graph storage to save entities/relations
            entity_extractor: Entity extractor to use
            relation_extractor: Relation extractor to use
            schema: Optional schema for validation
            enable_deduplication: Enable entity/relation deduplication
            enable_linking: Enable linking to existing entities
            enable_validation: Enable relation validation
            progress_callback: Optional callback for progress updates (message, progress_pct)
            embedding_client: Optional custom LLM client for generating embeddings
        """
        self.graph_store = graph_store
        self.entity_extractor = entity_extractor
        self.relation_extractor = relation_extractor
        self.schema = schema
        self.enable_deduplication = enable_deduplication
        self.enable_linking = enable_linking
        self.enable_validation = enable_validation
        self.progress_callback = progress_callback
        self.embedding_client = embedding_client

        # Initialize fusion components; a disabled stage is represented by None
        self.entity_deduplicator = EntityDeduplicator() if enable_deduplication else None
        self.entity_linker = EntityLinker(graph_store) if enable_linking else None
        self.relation_deduplicator = RelationDeduplicator() if enable_deduplication else None
        # Validation needs a schema, so it is only active when both are provided
        self.relation_validator = RelationValidator(schema) if enable_validation and schema else None

    @staticmethod
    def from_config(
        graph_store: GraphStore,
        entity_extractor: EntityExtractor,
        relation_extractor: RelationExtractor,
        schema: Optional[GraphSchema] = None,
        enable_deduplication: bool = True,
        enable_linking: bool = True,
        enable_validation: bool = True,
        progress_callback: Optional[Callable[[str, float], None]] = None,
    ) -> "GraphBuilder":
        """
        Create GraphBuilder with embedding client resolved from configuration

        The embedding client is resolved from the global settings
        (``kg_embedding_provider`` / ``kg_embedding_model``). If no provider is
        configured, or resolution fails, the builder is created without an
        embedding client (a warning is logged on failure).

        Args:
            graph_store: Graph storage to save entities/relations
            entity_extractor: Entity extractor to use
            relation_extractor: Relation extractor to use
            schema: Optional schema for validation
            enable_deduplication: Enable entity/relation deduplication
            enable_linking: Enable linking to existing entities
            enable_validation: Enable relation validation
            progress_callback: Optional callback for progress updates

        Returns:
            GraphBuilder instance with configured embedding client

        Example::

            from aiecs.config import get_settings
            from aiecs.llm.factory import LLMClientFactory

            # Register custom embedding provider
            LLMClientFactory.register_custom_provider("my_embedder", my_client)

            # Set environment variable
            os.environ["KG_EMBEDDING_PROVIDER"] = "my_embedder"

            # Create builder with auto-resolved embedding client
            builder = GraphBuilder.from_config(
                graph_store=store,
                entity_extractor=extractor,
                relation_extractor=rel_extractor
            )
        """
        # Imported lazily, as in the original implementation
        from aiecs.config import get_settings
        from aiecs.llm import resolve_llm_client

        settings = get_settings()

        # Resolve embedding client from configuration
        embedding_client = None
        if settings.kg_embedding_provider:
            try:
                embedding_client = resolve_llm_client(
                    provider=settings.kg_embedding_provider,
                    model=settings.kg_embedding_model,
                )
                logger.info(
                    "Using embedding provider: %s with model: %s",
                    settings.kg_embedding_provider,
                    settings.kg_embedding_model,
                )
            except Exception as e:
                # Best effort: fall back to a builder without embeddings
                logger.warning("Failed to resolve embedding client from config: %s", e)

        return GraphBuilder(
            graph_store=graph_store,
            entity_extractor=entity_extractor,
            relation_extractor=relation_extractor,
            schema=schema,
            enable_deduplication=enable_deduplication,
            enable_linking=enable_linking,
            enable_validation=enable_validation,
            progress_callback=progress_callback,
            embedding_client=embedding_client,
        )

    async def build_from_text(
        self,
        text: str,
        source: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> BuildResult:
        """
        Build knowledge graph from text

        Any exception raised by a pipeline stage is caught, logged, and
        recorded in ``BuildResult.errors`` with ``success`` set to ``False``;
        this method itself does not propagate stage errors.

        Args:
            text: Input text to process
            source: Optional source identifier (document name, URL, etc.)
            metadata: Optional metadata to attach to entities/relations

        Returns:
            BuildResult with statistics and errors
        """
        result = BuildResult(start_time=datetime.now())

        try:
            self._report_progress("Starting entity extraction", 0.1)

            # Step 1: Extract entities
            entities = await self.entity_extractor.extract_entities(text)

            if not entities:
                result.warnings.append("No entities extracted from text")
                return self._finalize_result(result)

            self._report_progress(f"Extracted {len(entities)} entities", 0.2)

            # Step 2: Deduplicate entities (within this text)
            if self.enable_deduplication and self.entity_deduplicator:
                original_count = len(entities)
                entities = await self.entity_deduplicator.deduplicate(entities)
                result.entities_deduplicated = original_count - len(entities)
                self._report_progress(f"Deduplicated to {len(entities)} entities", 0.3)

            # Step 3: Link entities to existing graph
            linked_entities: List[Any] = []
            new_entities: List[Any] = []

            if self.enable_linking and self.entity_linker:
                self._report_progress("Linking entities to graph", 0.4)
                link_results = await self.entity_linker.link_entities(entities)

                for link_result in link_results:
                    if link_result.linked:
                        linked_entities.append(link_result.existing_entity)
                        result.entities_linked += 1
                    else:
                        new_entities.append(link_result.new_entity)
            else:
                new_entities = list(entities)

            # Link results may carry None entities. Drop them from BOTH lists so
            # that the embedding, provenance and storage steps below never
            # dereference None (previously only the combined list was filtered).
            linked_entities = [e for e in linked_entities if e is not None]
            new_entities = [e for e in new_entities if e is not None]

            # Combine linked and new entities for relation extraction
            all_entities = linked_entities + new_entities

            # Step 4: Extract relations
            if len(all_entities) >= 2:
                self._report_progress(
                    f"Extracting relations from {len(all_entities)} entities",
                    0.5,
                )
                relations = await self.relation_extractor.extract_relations(text, all_entities)
                self._report_progress(f"Extracted {len(relations)} relations", 0.6)
            else:
                relations = []
                result.warnings.append("Not enough entities for relation extraction")

            # Step 5: Validate relations
            valid_relations = relations
            if self.enable_validation and self.relation_validator and relations:
                self._report_progress("Validating relations", 0.7)
                valid_relations = self.relation_validator.filter_valid_relations(relations, all_entities)
                invalid_count = len(relations) - len(valid_relations)
                if invalid_count > 0:
                    result.warnings.append(f"{invalid_count} relations failed validation")

            # Step 6: Deduplicate relations
            if self.enable_deduplication and self.relation_deduplicator and valid_relations:
                original_count = len(valid_relations)
                valid_relations = await self.relation_deduplicator.deduplicate(valid_relations)
                result.relations_deduplicated = original_count - len(valid_relations)
                self._report_progress(f"Deduplicated to {len(valid_relations)} relations", 0.8)

            # Step 7: Generate embeddings for entities (new entities only)
            if self.embedding_client and new_entities:
                self._report_progress("Generating embeddings for entities", 0.85)
                await self._generate_embeddings_for_entities(new_entities)

            # Step 8: Store in graph
            self._report_progress("Storing entities and relations in graph", 0.9)

            # Add provenance metadata
            if source or metadata:
                provenance: Dict[str, Any] = {"source": source} if source else {}
                if metadata:
                    provenance.update(metadata)

                self._attach_provenance(new_entities, provenance)
                self._attach_provenance(valid_relations, provenance)

            # Store entities
            for entity in new_entities:
                await self.graph_store.add_entity(entity)
                result.entities_added += 1

            # Store relations
            for relation in valid_relations:
                await self.graph_store.add_relation(relation)
                result.relations_added += 1

            self._report_progress("Build complete", 1.0)

        except Exception as e:
            # Keep the traceback in the log; the result only carries the message
            logger.exception("Graph build failed (source=%r)", source)
            result.success = False
            result.errors.append(f"Build failed: {str(e)}")

        return self._finalize_result(result)

    async def build_batch(
        self,
        texts: List[str],
        sources: Optional[List[str]] = None,
        parallel: bool = True,
        max_parallel: int = 5,
    ) -> List[BuildResult]:
        """
        Build graph from multiple texts in batch

        Args:
            texts: List of texts to process
            sources: Optional list of source identifiers (same length as texts).
                When omitted or empty, ``text_<index>`` identifiers are generated.
            parallel: Process in parallel (default: True)
            max_parallel: Maximum parallel tasks (default: 5); must be >= 1
                when ``parallel`` is True

        Returns:
            List of BuildResult objects (one per text)

        Raises:
            ValueError: If ``sources`` is given with a different length than
                ``texts``, or if ``parallel`` is True and ``max_parallel`` < 1
        """
        if sources and len(sources) != len(texts):
            raise ValueError("sources list must match texts list length")

        if not sources:
            sources = [f"text_{i}" for i in range(len(texts))]

        if not parallel:
            # Process sequentially
            sequential_results: List[BuildResult] = []
            for text, source in zip(texts, sources):
                sequential_results.append(await self.build_from_text(text, source))
            return sequential_results

        # A semaphore with value 0 would block every task forever, so reject
        # non-positive limits up front instead of hanging.
        if max_parallel < 1:
            raise ValueError("max_parallel must be at least 1")

        # Process in parallel with semaphore for concurrency control
        semaphore = asyncio.Semaphore(max_parallel)

        async def process_one(text, source):
            async with semaphore:
                return await self.build_from_text(text, source)

        tasks = [process_one(text, source) for text, source in zip(texts, sources)]
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        # Convert any raised exception into a failed BuildResult. gather() may
        # also hand back BaseException subclasses (e.g. CancelledError), so the
        # check is against BaseException rather than Exception.
        results: List[BuildResult] = []
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                error_result = BuildResult(success=False)
                error_result.errors.append(str(outcome))
                results.append(error_result)
            else:
                results.append(outcome)
        return results

    @staticmethod
    def _attach_provenance(items: List[Any], provenance: Dict[str, Any]) -> None:
        """
        Store a copy of ``provenance`` under ``properties["_provenance"]`` of each item

        Each item receives its own (shallow) copy so that later mutation of one
        item's provenance does not leak into the others.

        Args:
            items: Entities or relations exposing a ``properties`` attribute
            provenance: Provenance mapping to attach
        """
        for item in items:
            if not item.properties:
                item.properties = {}
            item.properties["_provenance"] = dict(provenance)

    def _report_progress(self, message: str, progress: float):
        """
        Report progress via callback

        Args:
            message: Progress message
            progress: Progress percentage (0.0-1.0)
        """
        if self.progress_callback:
            try:
                self.progress_callback(message, progress)
            except Exception as e:
                # Don't let callback errors break the pipeline
                logger.warning("Progress callback error: %s", e)

    def _finalize_result(self, result: BuildResult) -> BuildResult:
        """
        Finalize build result with timing information

        Args:
            result: BuildResult to finalize

        Returns:
            Finalized BuildResult (the same object, updated in place)
        """
        result.end_time = datetime.now()
        if result.start_time:
            result.duration_seconds = (result.end_time - result.start_time).total_seconds()
        return result

    async def _generate_embeddings_for_entities(self, entities: List[Any], model: Optional[str] = None) -> None:
        """
        Generate embeddings for entities using the configured embedding client

        Args:
            entities: List of entities to generate embeddings for
            model: Optional model name for embedding generation

        Note:
            This method modifies entities in-place by setting their embedding attribute.
            If no embedding client is configured, entities will not have embeddings.
            Failures are logged and swallowed so that a missing embedding never
            aborts the build.
        """
        if not self.embedding_client or not entities:
            return

        try:
            # Prepare texts for embedding: "<type>: <name>", falling back to the id
            texts = []
            for entity in entities:
                name = entity.properties.get("name") if entity.properties else None
                label = name if name else entity.id
                texts.append(f"{entity.entity_type}: {label}")

            # Generate embeddings
            embeddings = await self.embedding_client.get_embeddings(texts, model=model)

            # zip() stops at the shorter sequence; surface a mismatch instead of
            # silently leaving some entities without an embedding.
            if len(embeddings) != len(entities):
                logger.warning(
                    "Embedding count mismatch: got %d embeddings for %d entities",
                    len(embeddings),
                    len(entities),
                )

            # Assign embeddings to entities
            for entity, embedding in zip(entities, embeddings):
                entity.embedding = embedding

            logger.debug("Generated embeddings for %d entities", len(entities))
        except NotImplementedError:
            logger.debug("Embedding client does not support get_embeddings()")
        except Exception as e:
            logger.warning("Failed to generate embeddings: %s", e)