Source code for aiecs.application.knowledge_graph.search.reranker_strategies

# /*---------------------------------------------------------------------------------------------
#  *  Copyright (c) IRETBL Corporation. All rights reserved.
#  *  Licensed under the Apache-2.0. See License.txt in the project root for license information.
#  *--------------------------------------------------------------------------------------------*/
"""
Reranking Strategy Implementations

Concrete implementations of reranking strategies for different signals:
- Text similarity (BM25, Jaccard)
- Semantic similarity (vector embeddings)
- Structural importance (PageRank, centrality)
- Hybrid combination
"""

from typing import List, Optional, Dict
import numpy as np

from aiecs.application.knowledge_graph.search.reranker import RerankerStrategy
from aiecs.application.knowledge_graph.search.text_similarity import (
    BM25Scorer,
    jaccard_similarity_text,
    cosine_similarity_text,
)
from aiecs.domain.knowledge_graph.models.entity import Entity
from aiecs.infrastructure.graph_storage.base import GraphStore


class TextSimilarityReranker(RerankerStrategy):
    """
    Rerank entities by lexical similarity between the query and entity text.

    Each entity's score is a weighted blend of two signals:

    - a BM25 score computed over the texts of the entities being scored,
      min-max normalized across that batch, and
    - the Jaccard similarity between the query and the entity text.

    Example::

        reranker = TextSimilarityReranker(
            bm25_weight=0.7,
            jaccard_weight=0.3
        )
        scores = await reranker.score("machine learning", entities)
    """

    def __init__(
        self,
        bm25_weight: float = 0.7,
        jaccard_weight: float = 0.3,
        property_keys: Optional[List[str]] = None,
    ):
        """
        Store the blend weights and the optional property filter.

        Args:
            bm25_weight: Weight applied to the normalized BM25 score.
            jaccard_weight: Weight applied to the Jaccard score.
            property_keys: Property names to read text from. When omitted
                (or empty) every property of the entity is considered.

        Raises:
            ValueError: If the two weights do not sum to 1.0 (within 1e-6).
        """
        weight_total = bm25_weight + jaccard_weight
        if abs(weight_total - 1.0) > 1e-6:
            raise ValueError("bm25_weight + jaccard_weight must equal 1.0")

        self.bm25_weight = bm25_weight
        self.jaccard_weight = jaccard_weight
        self.property_keys = property_keys

    @property
    def name(self) -> str:
        """Identifier of this strategy."""
        return "text_similarity"

    def _extract_text(self, entity: Entity) -> str:
        """
        Join the textual property values of ``entity`` into one string.

        String values are taken as-is; for list/tuple values only the string
        items are kept. Everything else is skipped.
        """
        if self.property_keys:
            # Restrict to the configured keys; absent keys yield None and are skipped.
            candidates = [entity.properties.get(key) for key in self.property_keys]
        else:
            candidates = list(entity.properties.values())

        fragments: List[str] = []
        for candidate in candidates:
            if isinstance(candidate, str):
                fragments.append(candidate)
            elif isinstance(candidate, (list, tuple)):
                fragments.extend(str(item) for item in candidate if isinstance(item, str))
        return " ".join(fragments)

    async def score(self, query: str, entities: List[Entity], **kwargs) -> List[float]:
        """
        Score ``entities`` against ``query`` using BM25 and Jaccard similarity.

        Args:
            query: Query text.
            entities: Entities to score.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            One blended score per entity, in input order. An empty entity
            list gives ``[]``; an empty query gives all zeros.
        """
        if not entities:
            return []
        if not query:
            return [0.0] * len(entities)

        documents = [self._extract_text(entity) for entity in entities]

        # The batch being reranked doubles as the BM25 corpus.
        raw_bm25 = BM25Scorer(documents).score(query)

        if raw_bm25:
            lowest = min(raw_bm25)
            highest = max(raw_bm25)
            if highest > lowest:
                scaled_bm25 = [(value - lowest) / (highest - lowest) for value in raw_bm25]
            else:
                # All scores identical: every entity gets the full BM25 share.
                scaled_bm25 = [1.0] * len(raw_bm25)
        else:
            scaled_bm25 = [0.0] * len(entities)

        overlaps = [jaccard_similarity_text(query, document) for document in documents]

        blended: List[float] = []
        for bm25_part, jaccard_part in zip(scaled_bm25, overlaps):
            blended.append(self.bm25_weight * bm25_part + self.jaccard_weight * jaccard_part)
        return blended
class SemanticReranker(RerankerStrategy):
    """
    Semantic reranker using vector cosine similarity

    Scores each entity by the cosine similarity between its embedding and
    the supplied query embedding, mapped from [-1, 1] to [0, 1].

    Example::

        reranker = SemanticReranker()
        scores = await reranker.score(
            query="machine learning",
            entities=entities,
            query_embedding=[0.1, 0.2, ...]
        )
    """

    def __init__(self):
        """Initialize SemanticReranker (no configuration is required)."""

    @property
    def name(self) -> str:
        """Identifier of this strategy."""
        return "semantic"

    async def score(
        self,
        query: str,
        entities: List[Entity],
        query_embedding: Optional[List[float]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Compute semantic similarity scores

        Args:
            query: Query text (not used for scoring in this implementation)
            entities: Entities to score
            query_embedding: Optional query embedding vector
            **kwargs: Additional parameters (ignored)

        Returns:
            List of scores (0.0-1.0), one per entity in input order. An
            entity scores 0.0 when no query embedding is given, when the
            query or entity vector has zero norm, when the entity has no
            embedding, or when the vector lengths differ.
        """
        if not entities:
            return []

        if query_embedding is None:
            # No embedding provided, return zero scores
            return [0.0] * len(entities)

        query_vec = np.array(query_embedding, dtype=np.float32)
        query_norm = np.linalg.norm(query_vec)

        if query_norm == 0:
            return [0.0] * len(entities)

        scores: List[float] = []
        for entity in entities:
            if not entity.embedding:
                scores.append(0.0)
                continue

            entity_vec = np.array(entity.embedding, dtype=np.float32)

            # Vectors of different length cannot be compared.
            if len(query_vec) != len(entity_vec):
                scores.append(0.0)
                continue

            entity_norm = np.linalg.norm(entity_vec)
            if entity_norm == 0:
                scores.append(0.0)
                continue

            # Cosine similarity
            similarity = np.dot(query_vec, entity_vec) / (query_norm * entity_norm)

            # Map [-1, 1] onto [0, 1]. float32 rounding can push the cosine a
            # hair outside [-1, 1], so clamp to keep the documented range.
            normalized = float((similarity + 1) / 2)
            scores.append(min(1.0, max(0.0, normalized)))

        return scores
class StructuralReranker(RerankerStrategy):
    """
    Structural reranker using graph centrality and PageRank

    Scores entities based on their structural importance in the graph.
    Combines a personalized-PageRank score with degree centrality (number of
    incoming plus outgoing neighbors), each normalized by its maximum.

    Example::

        reranker = StructuralReranker(graph_store)
        scores = await reranker.score("query", entities)
    """

    def __init__(
        self,
        graph_store: GraphStore,
        pagerank_weight: float = 0.7,
        degree_weight: float = 0.3,
        use_cached_scores: bool = True,
    ):
        """
        Initialize StructuralReranker

        Args:
            graph_store: Graph storage backend
            pagerank_weight: Weight for PageRank scores (0.0-1.0)
            degree_weight: Weight for degree centrality (0.0-1.0)
            use_cached_scores: Whether to cache PageRank scores and degrees
                on this instance between calls

        Raises:
            ValueError: If the two weights do not sum to 1.0 (within 1e-6).
        """
        if abs(pagerank_weight + degree_weight - 1.0) > 1e-6:
            raise ValueError("pagerank_weight + degree_weight must equal 1.0")

        self.graph_store = graph_store
        self.pagerank_weight = pagerank_weight
        self.degree_weight = degree_weight
        self.use_cached_scores = use_cached_scores
        # Per-instance caches keyed by entity id.
        self._pagerank_cache: Dict[str, float] = {}
        self._degree_cache: Dict[str, int] = {}

    @property
    def name(self) -> str:
        """Identifier of this strategy."""
        return "structural"

    async def _compute_pagerank_scores(self, entity_ids: List[str]) -> Dict[str, float]:
        """
        Compute or retrieve cached PageRank scores.

        Returns a mapping with one entry per requested id; ids absent from
        the PageRank result get 0.0.
        """
        # A cache hit requires every id to be present. Membership (not
        # ``score > 0``) is tested so that a legitimately cached 0.0 does not
        # force a recomputation on every call.
        if self.use_cached_scores and all(eid in self._pagerank_cache for eid in entity_ids):
            return {eid: self._pagerank_cache[eid] for eid in entity_ids}

        # Imported locally, as in the original implementation.
        from aiecs.application.knowledge_graph.retrieval.retrieval_strategies import (
            PersonalizedPageRank,
        )

        ppr = PersonalizedPageRank(self.graph_store)

        # The first (up to) 10 stored entities act as seeds.
        # In practice, you might want to use seed entities from query context.
        # get_all_entities is available via PaginationMixinProtocol
        all_entities = await self.graph_store.get_all_entities()
        seed_ids = [e.id for e in all_entities[:10]]

        if not seed_ids:
            return {eid: 0.0 for eid in entity_ids}

        ppr_results = await ppr.retrieve(
            seed_entity_ids=seed_ids,
            max_results=len(entity_ids) * 2,
            alpha=0.15,
        )

        pagerank_scores = {entity.id: score for entity, score in ppr_results}

        # Normalize to [0, 1] by the largest score in this result set.
        if pagerank_scores:
            max_score = max(pagerank_scores.values())
            if max_score > 0:
                pagerank_scores = {eid: score / max_score for eid, score in pagerank_scores.items()}

        if self.use_cached_scores:
            self._pagerank_cache.update(pagerank_scores)

        return {eid: pagerank_scores.get(eid, 0.0) for eid in entity_ids}

    async def _compute_degree_scores(self, entity_ids: List[str]) -> Dict[str, float]:
        """
        Compute degree centrality scores.

        The degree of an entity is the number of outgoing plus incoming
        neighbors reported by the graph store. Degrees are divided by the
        largest degree among the requested ids.
        """
        # Start only from degrees that are really in the cache. Defaulting
        # missing ids to 0 would make every id look cached and the graph
        # store would never be queried.
        degrees: Dict[str, int] = {}
        if self.use_cached_scores:
            degrees = {eid: self._degree_cache[eid] for eid in entity_ids if eid in self._degree_cache}

        # Compute missing degrees
        for entity_id in entity_ids:
            if entity_id in degrees:
                continue
            neighbors_out = await self.graph_store.get_neighbors(entity_id, direction="outgoing")
            neighbors_in = await self.graph_store.get_neighbors(entity_id, direction="incoming")
            degree = len(neighbors_out) + len(neighbors_in)
            degrees[entity_id] = degree
            if self.use_cached_scores:
                self._degree_cache[entity_id] = degree

        # Normalize to [0, 1]
        max_degree = max(degrees.values(), default=0)
        if max_degree > 0:
            return {eid: deg / max_degree for eid, deg in degrees.items()}

        return {eid: 0.0 for eid in entity_ids}

    async def score(self, query: str, entities: List[Entity], **kwargs) -> List[float]:
        """
        Compute structural importance scores

        Args:
            query: Query text (not used, but required by interface)
            entities: Entities to score
            **kwargs: Additional parameters (ignored)

        Returns:
            List of scores (0.0-1.0), one per entity in input order
        """
        if not entities:
            return []

        entity_ids = [entity.id for entity in entities]

        pagerank_scores = await self._compute_pagerank_scores(entity_ids)
        degree_scores = await self._compute_degree_scores(entity_ids)

        # Weighted blend of the two normalized signals.
        return [
            self.pagerank_weight * pagerank_scores.get(entity.id, 0.0)
            + self.degree_weight * degree_scores.get(entity.id, 0.0)
            for entity in entities
        ]
class HybridReranker(RerankerStrategy):
    """
    Rerank entities with a weighted mix of three signals.

    Delegates to a ``TextSimilarityReranker``, a ``SemanticReranker`` and a
    ``StructuralReranker`` (all built with their default settings) and sums
    their scores using the configured weights.

    Example::

        reranker = HybridReranker(
            graph_store=store,
            text_weight=0.4,
            semantic_weight=0.4,
            structural_weight=0.2
        )
        scores = await reranker.score(
            query="machine learning",
            entities=entities,
            query_embedding=[0.1, 0.2, ...]
        )
    """

    def __init__(
        self,
        graph_store: GraphStore,
        text_weight: float = 0.4,
        semantic_weight: float = 0.4,
        structural_weight: float = 0.2,
    ):
        """
        Validate the weights and build the three sub-strategies.

        Args:
            graph_store: Graph storage backend handed to the structural reranker.
            text_weight: Weight of the text-similarity score.
            semantic_weight: Weight of the semantic-similarity score.
            structural_weight: Weight of the structural-importance score.

        Raises:
            ValueError: If the three weights do not sum to 1.0 (within 1e-6).
        """
        weight_total = text_weight + semantic_weight + structural_weight
        if abs(weight_total - 1.0) > 1e-6:
            raise ValueError("Weights must sum to 1.0")

        self.graph_store = graph_store
        self.text_weight = text_weight
        self.semantic_weight = semantic_weight
        self.structural_weight = structural_weight

        # One delegate per signal.
        self.text_reranker = TextSimilarityReranker()
        self.semantic_reranker = SemanticReranker()
        self.structural_reranker = StructuralReranker(graph_store)

    @property
    def name(self) -> str:
        """Identifier of this strategy."""
        return "hybrid"

    async def score(
        self,
        query: str,
        entities: List[Entity],
        query_embedding: Optional[List[float]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Blend the text, semantic and structural scores for each entity.

        Args:
            query: Query text.
            entities: Entities to score.
            query_embedding: Optional query embedding, forwarded to the
                semantic reranker.
            **kwargs: Extra keyword arguments, forwarded to every sub-strategy.

        Returns:
            One combined score per entity, in input order; ``[]`` when no
            entities are given.
        """
        if not entities:
            return []

        # The sub-strategies are awaited one after another: text, semantic, structural.
        lexical = await self.text_reranker.score(query, entities, **kwargs)
        embedding_based = await self.semantic_reranker.score(
            query,
            entities,
            query_embedding=query_embedding,
            **kwargs,
        )
        graph_based = await self.structural_reranker.score(query, entities, **kwargs)

        blended: List[float] = []
        for text_part, semantic_part, structural_part in zip(lexical, embedding_based, graph_based):
            blended.append(
                self.text_weight * text_part
                + self.semantic_weight * semantic_part
                + self.structural_weight * structural_part
            )
        return blended
class CrossEncoderReranker(RerankerStrategy):
    """
    Placeholder for a cross-encoder (transformer) based reranker.

    A cross-encoder scores the query and the entity text jointly, which is
    more accurate but slower than a bi-encoder. No model is loaded by this
    implementation: it currently falls back to text cosine similarity. For
    production use, integrate a cross-encoder model library
    (e.g., sentence-transformers).

    Example::

        reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
        scores = await reranker.score("machine learning", entities)
    """

    def __init__(self, model_name: Optional[str] = None, use_gpu: bool = False):
        """
        Record the model configuration.

        Args:
            model_name: Optional model name; stored but not used by the
                placeholder scoring.
            use_gpu: Whether a GPU should be used (if available); stored but
                not used by the placeholder scoring.
        """
        self._model = None
        self.model_name = model_name
        self.use_gpu = use_gpu

    @property
    def name(self) -> str:
        """Identifier of this strategy."""
        return "cross_encoder"

    def _extract_text(self, entity: Entity) -> str:
        """Concatenate all string property values (and string list/tuple items) of ``entity``."""
        fragments: List[str] = []
        for value in entity.properties.values():
            if isinstance(value, str):
                fragments.append(value)
                continue
            if isinstance(value, (list, tuple)):
                fragments.extend(str(item) for item in value if isinstance(item, str))
        return " ".join(fragments)

    async def score(self, query: str, entities: List[Entity], **kwargs) -> List[float]:
        """
        Score entities against the query (placeholder implementation).

        Args:
            query: Query text.
            entities: Entities to score.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            One score per entity, in input order. ``[]`` when no entities are
            given and all zeros when the query is empty.
        """
        if not entities:
            return []
        if not query:
            return [0.0] * len(entities)

        # Intended production path (not implemented here): lazily load a
        # sentence-transformers CrossEncoder (self.model_name, defaulting to
        # "cross-encoder/ms-marco-MiniLM-L-6-v2"), predict on [query, text]
        # pairs, then min-max normalize the predictions to [0, 1].
        #
        # Until then, text cosine similarity stands in for the model.
        results: List[float] = []
        for document in [self._extract_text(entity) for entity in entities]:
            results.append(cosine_similarity_text(query, document))
        return results