Source code for aiecs.domain.context.context_engine

# /*---------------------------------------------------------------------------------------------
#  *  Copyright (c) IRETBL Corporation. All rights reserved.
#  *  Licensed under the Apache-2.0. See License.txt in the project root for license information.
#  *--------------------------------------------------------------------------------------------*/
"""
ContextEngine: Advanced Context and Session Management Engine

This engine extends TaskContext capabilities to provide comprehensive
session management, conversation tracking, and persistent storage for BaseAIService.

Key Features:
1. Multi-session management (extends TaskContext from single task to multiple sessions)
2. Redis backend storage for persistence and scalability
3. Conversation history management with optimization
4. Performance metrics and analytics
5. Resource and lifecycle management
6. Integration with BaseServiceCheckpointer
"""

from aiecs.core.interface.storage_interface import (
    IStorageBackend,
    ICheckpointerBackend,
    IPermanentStorageBackend,
)
from aiecs.domain.task.task_context import TaskContext, ContextUpdate
import json
import logging
import uuid
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, cast
from dataclasses import dataclass, asdict, is_dataclass


[docs] class DateTimeEncoder(json.JSONEncoder): """Custom JSON encoder to handle datetime objects."""
[docs] def default(self, obj): if isinstance(obj, datetime): return obj.isoformat() return super().default(obj)
# Import TaskContext for base functionality # Import core storage interfaces # Redis client import - use existing infrastructure try: import redis.asyncio as redis from aiecs.infrastructure.persistence.redis_client import get_redis_client REDIS_AVAILABLE = True except ImportError: redis = None # type: ignore[assignment] get_redis_client = None # type: ignore[assignment] REDIS_AVAILABLE = False logger = logging.getLogger(__name__)
[docs] @dataclass class SessionMetrics: """Session-level performance metrics.""" session_id: str user_id: str created_at: datetime last_activity: datetime request_count: int = 0 error_count: int = 0 total_processing_time: float = 0.0 status: str = "active" # active, completed, failed, expired
[docs] def to_dict(self) -> Dict[str, Any]: return { **asdict(self), "created_at": self.created_at.isoformat(), "last_activity": self.last_activity.isoformat(), }
[docs] @classmethod def from_dict(cls, data: Dict[str, Any]) -> "SessionMetrics": data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) data["last_activity"] = datetime.fromisoformat(data["last_activity"]) return cls(**data)
[docs] @dataclass class ConversationMessage: """Structured conversation message.""" role: str # user, assistant, system content: str timestamp: datetime metadata: Optional[Dict[str, Any]] = None
[docs] def to_dict(self) -> Dict[str, Any]: return { "role": self.role, "content": self.content, "timestamp": self.timestamp.isoformat(), "metadata": self.metadata or {}, }
[docs] @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ConversationMessage": data = data.copy() data["timestamp"] = datetime.fromisoformat(data["timestamp"]) return cls(**data)
[docs] @dataclass class CompressionConfig: """ Configuration for conversation compression. Provides flexible control over compression behavior with multiple strategies to manage conversation history size and reduce token usage. **Compression Strategies:** - truncate: Fast truncation, keeps most recent N messages (no LLM required) - summarize: LLM-based summarization of older messages - semantic: Embedding-based deduplication of similar messages - hybrid: Combination of multiple strategies applied sequentially **Key Features:** - Automatic compression triggers based on message count - Custom prompt templates for summarization - Configurable similarity thresholds for semantic deduplication - Performance timeouts to prevent long-running operations Attributes: strategy: Compression strategy to use. One of: "truncate", "summarize", "semantic", "hybrid" max_messages: Maximum messages to keep (for truncation strategy) keep_recent: Always keep N most recent messages (applies to all strategies) summary_prompt_template: Custom prompt template for summarization (uses {messages} placeholder) summary_max_tokens: Maximum tokens for summary output include_summary_in_history: Whether to add summary as system message in history similarity_threshold: Similarity threshold for semantic deduplication (0.0-1.0) embedding_model: Embedding model name for semantic deduplication hybrid_strategies: List of strategies to combine for hybrid mode (default: ["truncate", "summarize"]) auto_compress_enabled: Enable automatic compression when threshold exceeded auto_compress_threshold: Message count threshold to trigger auto-compression auto_compress_target: Target message count after auto-compression compression_timeout: Maximum time for compression operation in seconds Examples: # Example 1: Basic truncation configuration config = CompressionConfig( strategy="truncate", max_messages=50, keep_recent=10 ) # Example 2: LLM-based summarization config = CompressionConfig( strategy="summarize", keep_recent=10, summary_max_tokens=500, include_summary_in_history=True ) # Example 3: Semantic deduplication config = CompressionConfig( strategy="semantic", keep_recent=10, similarity_threshold=0.95, embedding_model="text-embedding-ada-002" ) # Example 4: Hybrid strategy (truncate then summarize) config = CompressionConfig( strategy="hybrid", hybrid_strategies=["truncate", "summarize"], keep_recent=10, summary_max_tokens=500 ) # Example 5: Auto-compression enabled config = CompressionConfig( auto_compress_enabled=True, auto_compress_threshold=100, auto_compress_target=50, strategy="summarize", keep_recent=10 ) # Example 6: Custom summarization prompt config = CompressionConfig( strategy="summarize", summary_prompt_template=( "Summarize the following conversation focusing on " "key decisions and action items:\n\n{messages}" ), summary_max_tokens=300 ) """ # Strategy selection strategy: str = "truncate" # truncate, summarize, semantic, hybrid # Truncation settings max_messages: int = 50 # Maximum messages to keep keep_recent: int = 10 # Always keep N most recent messages # Summarization settings (LLM-based) summary_prompt_template: Optional[str] = None # Custom prompt template summary_max_tokens: int = 500 # Max tokens for summary include_summary_in_history: bool = True # Add summary as system message # Semantic deduplication settings (embedding-based) similarity_threshold: float = 0.95 # Messages above this similarity are duplicates embedding_model: str = "text-embedding-ada-002" # Embedding model to use # Hybrid strategy settings hybrid_strategies: Optional[List[str]] = None # Strategies to combine (default: ["truncate", "summarize"]) # Auto-compression triggers auto_compress_enabled: bool = False # Enable automatic compression auto_compress_threshold: int = 100 # Trigger when message count exceeds this auto_compress_target: int = 50 # Target message count after compression # Performance settings compression_timeout: float = 30.0 # Max time for compression operation (seconds)
[docs] def __post_init__(self): """Validate and set defaults.""" if self.hybrid_strategies is None: self.hybrid_strategies = ["truncate", "summarize"] # Validate strategy valid_strategies = ["truncate", "summarize", "semantic", "hybrid"] if self.strategy not in valid_strategies: raise ValueError(f"Invalid strategy '{self.strategy}'. " f"Must be one of: {', '.join(valid_strategies)}")
[docs] class ContextEngine(IStorageBackend, ICheckpointerBackend): """ Advanced Context and Session Management Engine. Implements core storage interfaces to provide comprehensive session management with Redis backend storage for BaseAIService and BaseServiceCheckpointer. This implementation follows the middleware's core interface pattern, enabling dependency inversion and clean architecture. **Key Features:** - Multi-session management with Redis backend - Conversation history management with compression - Performance metrics and analytics - Resource and lifecycle management - Integration with BaseServiceCheckpointer **Compression Strategies:** - truncate: Fast truncation (no LLM required) - summarize: LLM-based summarization - semantic: Embedding-based deduplication - hybrid: Combination of multiple strategies Examples: # Example 1: Basic ContextEngine initialization engine = ContextEngine() await engine.initialize() # Create session session = await engine.create_session( session_id="session-123", user_id="user-456" ) # Add conversation messages await engine.add_conversation_message( session_id="session-123", role="user", content="Hello, I need help" ) # Example 2: ContextEngine with compression (truncation strategy) from aiecs.domain.context.context_engine import CompressionConfig compression_config = CompressionConfig( strategy="truncate", max_messages=50, keep_recent=10 # Always keep 10 most recent messages ) engine = ContextEngine(compression_config=compression_config) await engine.initialize() # Add many messages for i in range(100): await engine.add_conversation_message( session_id="session-123", role="user" if i % 2 == 0 else "assistant", content=f"Message {i}" ) # Compress conversation (truncates to 10 most recent) result = await engine.compress_conversation("session-123") print(f"Compressed from {result['original_count']} to {result['compressed_count']} messages") # Example 3: ContextEngine with LLM-based summarization from aiecs.llm import OpenAIClient llm_client = OpenAIClient() compression_config = CompressionConfig( strategy="summarize", keep_recent=10, # Keep 10 most recent messages summary_max_tokens=500, include_summary_in_history=True ) engine = ContextEngine( compression_config=compression_config, llm_client=llm_client # Required for summarization ) await engine.initialize() # Add conversation for i in range(50): await engine.add_conversation_message( session_id="session-123", role="user" if i % 2 == 0 else "assistant", content=f"Message {i}: Important information about topic {i % 5}" ) # Compress using summarization result = await engine.compress_conversation("session-123", strategy="summarize") print(f"Compressed: {result['original_count']} -> {result['compressed_count']} messages") print(f"Compression ratio: {result['compression_ratio']:.1%}") # Example 4: ContextEngine with semantic deduplication compression_config = CompressionConfig( strategy="semantic", keep_recent=10, similarity_threshold=0.95, # Remove messages >95% similar embedding_model="text-embedding-ada-002" ) engine = ContextEngine( compression_config=compression_config, llm_client=llm_client # Required for embeddings ) await engine.initialize() # Add conversation with similar messages messages = [ "What's the weather?", "What's the weather today?", "Tell me about the weather", "What's the temperature?" ] for msg in messages: await engine.add_conversation_message( session_id="session-123", role="user", content=msg ) # Compress using semantic deduplication result = await engine.compress_conversation("session-123", strategy="semantic") print(f"Removed {result['original_count'] - result['compressed_count']} similar messages") # Example 5: ContextEngine with hybrid compression compression_config = CompressionConfig( strategy="hybrid", hybrid_strategies=["truncate", "summarize"], # Apply truncate then summarize keep_recent=10, summary_max_tokens=500 ) engine = ContextEngine( compression_config=compression_config, llm_client=llm_client ) await engine.initialize() # Compress using hybrid strategy result = await engine.compress_conversation("session-123", strategy="hybrid") # Example 6: Auto-compression on message limit compression_config = CompressionConfig( auto_compress_enabled=True, auto_compress_threshold=100, # Trigger at 100 messages auto_compress_target=50, # Compress to 50 messages strategy="summarize", keep_recent=10 ) engine = ContextEngine( compression_config=compression_config, llm_client=llm_client ) await engine.initialize() # Add messages - auto-compression triggers at 100 for i in range(105): await engine.add_conversation_message( session_id="session-123", role="user" if i % 2 == 0 else "assistant", content=f"Message {i}" ) # Check if auto-compression was triggered result = await engine.auto_compress_on_limit("session-123") if result: print(f"Auto-compressed: {result['original_count']} -> {result['compressed_count']}") # Example 7: Custom compression prompt template compression_config = CompressionConfig( strategy="summarize", summary_prompt_template=( "Summarize the following conversation focusing on key decisions, " "action items, and important facts. Keep it concise:\n\n{messages}" ), summary_max_tokens=300 ) engine = ContextEngine( compression_config=compression_config, llm_client=llm_client ) await engine.initialize() # Compress with custom prompt result = await engine.compress_conversation("session-123") # Example 8: Get compressed context in different formats engine = ContextEngine(compression_config=compression_config, llm_client=llm_client) await engine.initialize() # Get as formatted string context_string = await engine.get_compressed_context( session_id="session-123", format="string", compress_first=True # Compress before returning ) print(context_string) # Get as messages list messages = await engine.get_compressed_context( session_id="session-123", format="messages", compress_first=False # Use existing compressed version ) # Get as dictionary context_dict = await engine.get_compressed_context( session_id="session-123", format="dict" ) # Example 9: Runtime compression config override engine = ContextEngine( compression_config=CompressionConfig(strategy="truncate"), llm_client=llm_client ) await engine.initialize() # Override compression config for specific operation custom_config = CompressionConfig( strategy="summarize", summary_max_tokens=1000 ) result = await engine.compress_conversation( session_id="session-123", config_override=custom_config ) # Example 10: Compression with custom LLM client class CustomLLMClient: provider_name = "custom" async def generate_text(self, messages, **kwargs): # Custom summarization logic return LLMResponse(content="Custom summary...") async def get_embeddings(self, texts, model): # Custom embedding logic return [[0.1] * 1536 for _ in texts] custom_llm = CustomLLMClient() compression_config = CompressionConfig(strategy="semantic") engine = ContextEngine( compression_config=compression_config, llm_client=custom_llm # Custom LLM client for compression ) await engine.initialize() # Compress using custom LLM client result = await engine.compress_conversation("session-123", strategy="semantic") """
[docs] def __init__( self, use_existing_redis: bool = True, compression_config: Optional[CompressionConfig] = None, llm_client: Optional[Any] = None, permanent_backend: Optional[IPermanentStorageBackend] = None, ): """ Initialize ContextEngine. Args: use_existing_redis: Whether to use the existing Redis client from infrastructure (已弃用: 现在总是创建独立的 RedisClient 实例以避免事件循环冲突) compression_config: Optional compression configuration for conversation compression llm_client: Optional LLM client for summarization and embeddings (must implement LLMClientProtocol) permanent_backend: Optional backend for dual-write disk persistence (e.g. ClickHouse). Writes are fire-and-forget; failures do not block Redis path. """ self.use_existing_redis = use_existing_redis self.redis_client: Optional[redis.Redis] = None self._redis_client_wrapper: Optional[Any] = None # RedisClient 包装器实例 self._permanent_backend: Optional[IPermanentStorageBackend] = permanent_backend # Fallback to memory storage if Redis not available self._memory_sessions: Dict[str, SessionMetrics] = {} self._memory_conversations: Dict[str, List[ConversationMessage]] = {} self._memory_contexts: Dict[str, TaskContext] = {} self._memory_checkpoints: Dict[str, Dict[str, Any]] = {} # Configuration self.session_ttl = 3600 * 24 # 24 hours default TTL self.conversation_limit = 1000 # Max messages per conversation self.checkpoint_ttl = 3600 * 24 * 7 # 7 days for checkpoints # Compression configuration (Phase 6) self.compression_config = compression_config or CompressionConfig() self.llm_client = llm_client # Metrics self._global_metrics = { "total_sessions": 0, "active_sessions": 0, "total_messages": 0, "total_checkpoints": 0, } logger.info(f"ContextEngine initialized with compression strategy: {self.compression_config.strategy}")
[docs] async def initialize(self) -> bool: """Initialize Redis connection and validate setup.""" if not REDIS_AVAILABLE: logger.warning("Redis not available, using memory storage") return True try: # ✅ 修复方案:在当前事件循环中创建新的 RedisClient 实例 # # 问题根源: # - 全局 RedisClient 单例在应用启动的事件循环A中创建 # - ContextEngine 可能在不同的事件循环B中被初始化(例如在请求处理中) # - redis.asyncio 的连接池绑定到创建时的事件循环 # - 跨事件循环使用会导致 "Task got Future attached to a different loop" 错误 # # 解决方案: # - 为每个 ContextEngine 实例创建独立的 RedisClient # - 使用 RedisClient 包装器保持架构一致性 # - 在当前事件循环中初始化,确保事件循环匹配 from aiecs.infrastructure.persistence.redis_client import ( RedisClient, ) # 创建专属的 RedisClient 实例(在当前事件循环中) self._redis_client_wrapper = RedisClient() await self._redis_client_wrapper.initialize() # 获取底层 redis.Redis 客户端用于现有代码 self.redis_client = await self._redis_client_wrapper.get_client() # Test connection await self.redis_client.ping() logger.info("ContextEngine connected to Redis successfully using RedisClient wrapper in current event loop") # Initialize permanent backend for dual-write (ClickHouse, etc.) if self._permanent_backend: try: if await self._permanent_backend.initialize(): logger.info("ContextEngine dual-write to permanent backend enabled") else: logger.warning("Permanent backend init failed, continuing without dual-write") self._permanent_backend = None except Exception as e: logger.warning(f"Permanent backend init error: {e}, continuing without dual-write") self._permanent_backend = None return True except Exception as e: logger.error(f"Failed to connect to Redis: {e}") logger.warning("Falling back to memory storage") self.redis_client = None self._redis_client_wrapper = None return False
[docs] async def close(self): """Close Redis connection and permanent backend.""" if hasattr(self, "_redis_client_wrapper") and self._redis_client_wrapper: # 使用 RedisClient 包装器的 close 方法 await self._redis_client_wrapper.close() self._redis_client_wrapper = None self.redis_client = None elif self.redis_client: # 兼容性处理:直接关闭 redis 客户端 await self.redis_client.close() self.redis_client = None if hasattr(self, "_permanent_backend") and self._permanent_backend: try: await self._permanent_backend.close() except Exception as e: logger.warning(f"Error closing permanent backend: {e}") self._permanent_backend = None
async def _fire_permanent(self, coro) -> None: """Fire-and-forget call to permanent backend. Logs errors, never raises.""" if not self._permanent_backend: return try: await coro except Exception as e: logger.debug(f"Permanent backend write failed (non-blocking): {e}") # ==================== Session Management ====================
[docs] async def create_session(self, session_id: str, user_id: str, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """Create a new session.""" now = datetime.utcnow() session = SessionMetrics( session_id=session_id, user_id=user_id, created_at=now, last_activity=now, ) # Store session await self._store_session(session) # Dual-write: append session create event to permanent backend if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_session_event( session_id=session_id, user_id=user_id, event_type="create", payload=session.to_dict(), created_at=now.isoformat(), ) ) # Create associated TaskContext task_context = TaskContext( { "user_id": user_id, "chat_id": session_id, "metadata": metadata or {}, } ) await self._store_task_context(session_id, task_context) # Update metrics self._global_metrics["total_sessions"] += 1 self._global_metrics["active_sessions"] += 1 logger.info(f"Created session {session_id} for user {user_id}") return session.to_dict()
[docs] async def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """Get session by ID.""" if self.redis_client: try: data = await self.redis_client.hget("sessions", session_id) # type: ignore[misc] if data: session = SessionMetrics.from_dict(json.loads(data)) return session.to_dict() except Exception as e: logger.error(f"Failed to get session from Redis: {e}") # Fallback to memory memory_session: Optional[SessionMetrics] = self._memory_sessions.get(session_id) return memory_session.to_dict() if memory_session else None
[docs] async def update_session( self, session_id: str, updates: Optional[Dict[str, Any]] = None, increment_requests: bool = False, add_processing_time: float = 0.0, mark_error: bool = False, ) -> bool: """Update session with activity and metrics.""" session_data = await self.get_session(session_id) if not session_data: return False # Convert dict to SessionMetrics if needed session: SessionMetrics if isinstance(session_data, dict): session = SessionMetrics.from_dict(session_data) else: session = session_data # Update activity session.last_activity = datetime.utcnow() # Update metrics if increment_requests: session.request_count += 1 if add_processing_time > 0: session.total_processing_time += add_processing_time if mark_error: session.error_count += 1 # Apply custom updates if updates: for key, value in updates.items(): if hasattr(session, key): setattr(session, key, value) # Store updated session await self._store_session(session) # Dual-write: append session update event if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_session_event( session_id=session_id, user_id=session.user_id, event_type="update", payload=session.to_dict(), created_at=datetime.utcnow().isoformat(), ) ) return True
[docs] async def end_session(self, session_id: str, status: str = "completed") -> bool: """End a session and update metrics.""" session_data = await self.get_session(session_id) if not session_data: return False # Convert dict to SessionMetrics if needed session = SessionMetrics.from_dict(session_data) if isinstance(session_data, dict) else session_data session.status = status session.last_activity = datetime.utcnow() # Store final state await self._store_session(session) # Dual-write: append session end event if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_session_event( session_id=session_id, user_id=session.user_id, event_type="end", payload=session.to_dict(), created_at=datetime.utcnow().isoformat(), ) ) # Update global metrics self._global_metrics["active_sessions"] = max(0, self._global_metrics["active_sessions"] - 1) logger.info(f"Ended session {session_id} with status: {status}") return True
async def _store_session(self, session: SessionMetrics): """Store session to Redis or memory.""" if self.redis_client: try: await self.redis_client.hset( # type: ignore[misc] "sessions", session.session_id, json.dumps(session.to_dict(), cls=DateTimeEncoder), ) await self.redis_client.expire("sessions", self.session_ttl) return except Exception as e: logger.error(f"Failed to store session to Redis: {e}") # Fallback to memory self._memory_sessions[session.session_id] = session # ==================== Conversation Management ====================
[docs] async def add_conversation_message( self, session_id: str, role: str, content: str, metadata: Optional[Dict[str, Any]] = None, ) -> bool: """Add message to conversation history.""" message = ConversationMessage( role=role, content=content, timestamp=datetime.utcnow(), metadata=metadata, ) # Store message await self._store_conversation_message(session_id, message) # Update session activity await self.update_session(session_id) # Update global metrics self._global_metrics["total_messages"] += 1 return True
[docs] async def get_conversation_history(self, session_id: str, limit: int = 50) -> List[Dict[str, Any]]: """Get conversation history for a session.""" if self.redis_client: try: messages_data = await self.redis_client.lrange(f"conversation:{session_id}", -limit, -1) # type: ignore[misc] # Since lpush adds to the beginning, we need to reverse to get # chronological order messages = [ConversationMessage.from_dict(json.loads(msg)) for msg in reversed(messages_data)] return [msg.to_dict() for msg in messages] except Exception as e: logger.error(f"Failed to get conversation from Redis: {e}") # Fallback to memory messages = self._memory_conversations.get(session_id, []) message_list = messages[-limit:] if limit > 0 else messages return [msg.to_dict() for msg in message_list]
async def _store_conversation_message(self, session_id: str, message: ConversationMessage): """Store conversation message to Redis or memory.""" if self.redis_client: try: # Add to list await self.redis_client.lpush( # type: ignore[misc] f"conversation:{session_id}", json.dumps(message.to_dict(), cls=DateTimeEncoder), ) # Trim to limit await self.redis_client.ltrim(f"conversation:{session_id}", -self.conversation_limit, -1) # type: ignore[misc] # Set TTL await self.redis_client.expire(f"conversation:{session_id}", self.session_ttl) # Dual-write: append to permanent backend if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_conversation_message( session_id=session_id, role=message.role, content=message.content, metadata=message.metadata, created_at=message.timestamp.isoformat(), ) ) return except Exception as e: logger.error(f"Failed to store message to Redis: {e}") # Fallback to memory if session_id not in self._memory_conversations: self._memory_conversations[session_id] = [] self._memory_conversations[session_id].append(message) # Trim to limit if len(self._memory_conversations[session_id]) > self.conversation_limit: self._memory_conversations[session_id] = self._memory_conversations[session_id][-self.conversation_limit :] # Dual-write when using memory fallback if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_conversation_message( session_id=session_id, role=message.role, content=message.content, metadata=message.metadata, created_at=message.timestamp.isoformat(), ) ) # ==================== TaskContext Integration ====================
[docs] async def get_task_context(self, session_id: str) -> Optional[TaskContext]: """Get TaskContext for a session.""" if self.redis_client: try: data = await self.redis_client.hget("task_contexts", session_id) # type: ignore[misc] if data: context_data = json.loads(data) # Reconstruct TaskContext from stored data return self._reconstruct_task_context(context_data) except Exception as e: logger.error(f"Failed to get TaskContext from Redis: {e}") # Fallback to memory return self._memory_contexts.get(session_id)
def _sanitize_dataclasses(self, obj: Any) -> Any: """ Recursively convert dataclasses to dictionaries for JSON serialization. This method handles: - Dataclass instances -> dict (via asdict) - Nested dataclasses in dictionaries - Nested dataclasses in lists - Other types -> pass through Args: obj: Object to sanitize Returns: Sanitized object (JSON-serializable) """ # Handle dataclass instances if is_dataclass(obj) and not isinstance(obj, type): logger.debug(f"Converting dataclass {type(obj).__name__} to dict for serialization") # Convert dataclass to dict and recursively sanitize return self._sanitize_dataclasses(asdict(obj)) # Handle dictionaries if isinstance(obj, dict): return {key: self._sanitize_dataclasses(value) for key, value in obj.items()} # Handle lists and tuples if isinstance(obj, (list, tuple)): sanitized_list = [self._sanitize_dataclasses(item) for item in obj] return sanitized_list if isinstance(obj, list) else tuple(sanitized_list) # Handle sets if isinstance(obj, set): return [self._sanitize_dataclasses(item) for item in obj] # All other types pass through return obj async def _store_task_context(self, session_id: str, context: TaskContext): """ Store TaskContext to Redis or memory. Automatically converts dataclasses to dictionaries to ensure JSON serialization compatibility. """ if self.redis_client: try: # Get context dict and sanitize dataclasses context_dict = context.to_dict() sanitized_dict = self._sanitize_dataclasses(context_dict) await self.redis_client.hset( # type: ignore[misc] "task_contexts", session_id, json.dumps(sanitized_dict, cls=DateTimeEncoder), ) await self.redis_client.expire("task_contexts", self.session_ttl) # Dual-write: append task context snapshot if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_task_context_snapshot( session_id=session_id, context_data=sanitized_dict, created_at=datetime.utcnow().isoformat(), ) ) return except Exception as e: logger.error(f"Failed to store TaskContext to Redis: {e}") # Fallback to memory self._memory_contexts[session_id] = context # Dual-write when using memory fallback if self._permanent_backend: context_dict = context.to_dict() sanitized_dict = self._sanitize_dataclasses(context_dict) await self._fire_permanent( self._permanent_backend.append_task_context_snapshot( session_id=session_id, context_data=sanitized_dict, created_at=datetime.utcnow().isoformat(), ) ) def _reconstruct_task_context(self, data: Dict[str, Any]) -> TaskContext: """Reconstruct TaskContext from stored data.""" # Create new TaskContext with stored data context = TaskContext(data) # Restore context history if "context_history" in data: context.context_history = [ ContextUpdate( timestamp=entry["timestamp"], update_type=entry["update_type"], data=entry["data"], metadata=entry["metadata"], ) for entry in data["context_history"] ] return context # ==================== Checkpoint Management (for BaseServiceCheckpointer)
[docs] async def store_checkpoint( self, thread_id: str, checkpoint_id: str, checkpoint_data: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None, ) -> bool: """ Store checkpoint data for LangGraph workflows. Automatically converts dataclasses to dictionaries to ensure JSON serialization compatibility. """ # Sanitize checkpoint data to handle dataclasses sanitized_data = self._sanitize_dataclasses(checkpoint_data) sanitized_metadata = self._sanitize_dataclasses(metadata or {}) checkpoint = { "checkpoint_id": checkpoint_id, "thread_id": thread_id, "data": sanitized_data, "metadata": sanitized_metadata, "created_at": datetime.utcnow().isoformat(), } if self.redis_client: try: # Store checkpoint await self.redis_client.hset( # type: ignore[misc] f"checkpoints:{thread_id}", checkpoint_id, json.dumps(checkpoint, cls=DateTimeEncoder), ) # Set TTL await self.redis_client.expire(f"checkpoints:{thread_id}", self.checkpoint_ttl) # Dual-write: append checkpoint to permanent backend if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_checkpoint( thread_id=thread_id, checkpoint_id=checkpoint_id, checkpoint_data=sanitized_data, metadata=sanitized_metadata, created_at=checkpoint["created_at"], ) ) # Update global metrics self._global_metrics["total_checkpoints"] += 1 return True except Exception as e: logger.error(f"Failed to store checkpoint to Redis: {e}") # Fallback to memory key = f"{thread_id}:{checkpoint_id}" self._memory_checkpoints[key] = checkpoint # Dual-write when using memory fallback if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_checkpoint( thread_id=thread_id, checkpoint_id=checkpoint_id, checkpoint_data=sanitized_data, metadata=sanitized_metadata, created_at=checkpoint["created_at"], ) ) return True
[docs] async def get_checkpoint(self, thread_id: str, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]: """Get checkpoint data. If checkpoint_id is None, get the latest.""" if self.redis_client: try: if checkpoint_id: # Get specific checkpoint data = await self.redis_client.hget(f"checkpoints:{thread_id}", checkpoint_id) # type: ignore[misc] if data: return cast(Dict[str, Any], json.loads(data)) else: # Get latest checkpoint checkpoints = await self.redis_client.hgetall(f"checkpoints:{thread_id}") # type: ignore[misc] if checkpoints: # Sort by creation time and get latest latest = max( checkpoints.values(), key=lambda x: json.loads(x)["created_at"], ) return cast(Dict[str, Any], json.loads(latest)) except Exception as e: logger.error(f"Failed to get checkpoint from Redis: {e}") # Fallback to memory if checkpoint_id: key = f"{thread_id}:{checkpoint_id}" return self._memory_checkpoints.get(key) else: # Get latest from memory thread_checkpoints = {k: v for k, v in self._memory_checkpoints.items() if k.startswith(f"{thread_id}:")} if thread_checkpoints: latest_key = max( thread_checkpoints.keys(), key=lambda k: thread_checkpoints[k]["created_at"], ) return thread_checkpoints[latest_key] return None
[docs] async def list_checkpoints(self, thread_id: str, limit: int = 10) -> List[Dict[str, Any]]: """List checkpoints for a thread, ordered by creation time (newest first).""" if self.redis_client: try: checkpoints_data = await self.redis_client.hgetall(f"checkpoints:{thread_id}") # type: ignore[misc] checkpoints = [json.loads(data) for data in checkpoints_data.values()] # Sort by creation time (newest first) checkpoints.sort(key=lambda x: x["created_at"], reverse=True) return checkpoints[:limit] except Exception as e: logger.error(f"Failed to list checkpoints from Redis: {e}") # Fallback to memory thread_checkpoints = [v for k, v in self._memory_checkpoints.items() if k.startswith(f"{thread_id}:")] thread_checkpoints.sort(key=lambda x: x["created_at"], reverse=True) return thread_checkpoints[:limit]
# ==================== Cleanup and Maintenance ====================
[docs] async def cleanup_expired_sessions(self, max_idle_hours: int = 24) -> int: """Clean up expired sessions and associated data.""" cutoff_time = datetime.utcnow() - timedelta(hours=max_idle_hours) cleaned_count = 0 if self.redis_client: try: # Get all sessions sessions_data = await self.redis_client.hgetall("sessions") # type: ignore[misc] expired_sessions = [] for session_id, data in sessions_data.items(): session = SessionMetrics.from_dict(json.loads(data)) if session.last_activity < cutoff_time: expired_sessions.append(session_id) # Clean up expired sessions for session_id in expired_sessions: await self._cleanup_session_data(session_id) cleaned_count += 1 except Exception as e: logger.error(f"Failed to cleanup expired sessions from Redis: {e}") else: # Memory cleanup expired_sessions = [session_id for session_id, session in self._memory_sessions.items() if session.last_activity < cutoff_time] for session_id in expired_sessions: await self._cleanup_session_data(session_id) cleaned_count += 1 if cleaned_count > 0: logger.info(f"Cleaned up {cleaned_count} expired sessions") return cleaned_count
async def _cleanup_session_data(self, session_id: str): """Clean up all data associated with a session.""" if self.redis_client: try: # Remove session await self.redis_client.hdel("sessions", session_id) # type: ignore[misc] # Remove conversation await self.redis_client.delete(f"conversation:{session_id}") # Remove task context await self.redis_client.hdel("task_contexts", session_id) # type: ignore[misc] # Remove checkpoints await self.redis_client.delete(f"checkpoints:{session_id}") except Exception as e: logger.error(f"Failed to cleanup session data from Redis: {e}") else: # Memory cleanup self._memory_sessions.pop(session_id, None) self._memory_conversations.pop(session_id, None) self._memory_contexts.pop(session_id, None) # Remove checkpoints checkpoint_keys = [k for k in self._memory_checkpoints.keys() if k.startswith(f"{session_id}:")] for key in checkpoint_keys: self._memory_checkpoints.pop(key, None) # ==================== Metrics and Health ====================
[docs] async def get_metrics(self) -> Dict[str, Any]: """Get comprehensive metrics.""" active_sessions_count = 0 if self.redis_client: try: sessions_data = await self.redis_client.hgetall("sessions") # type: ignore[misc] active_sessions_count = len([s for s in sessions_data.values() if json.loads(s)["status"] == "active"]) except Exception as e: logger.error(f"Failed to get metrics from Redis: {e}") else: active_sessions_count = len([s for s in self._memory_sessions.values() if s.status == "active"]) return { **self._global_metrics, "active_sessions": active_sessions_count, "storage_backend": "redis" if self.redis_client else "memory", "redis_connected": self.redis_client is not None, "timestamp": datetime.utcnow().isoformat(), }
[docs] async def health_check(self) -> Dict[str, Any]: """Perform health check.""" health: Dict[str, Any] = { "status": "healthy", "storage_backend": "redis" if self.redis_client else "memory", "redis_connected": False, "issues": [], } issues: List[str] = health["issues"] # Type narrowing # Check Redis connection if self.redis_client: try: await self.redis_client.ping() health["redis_connected"] = True except Exception as e: issues.append(f"Redis connection failed: {e}") health["status"] = "degraded" # Check memory usage (basic check) if not self.redis_client: total_memory_items = len(self._memory_sessions) + len(self._memory_conversations) + len(self._memory_contexts) + len(self._memory_checkpoints) if total_memory_items > 10000: # Arbitrary threshold issues.append(f"High memory usage: {total_memory_items} items") health["status"] = "warning" health["issues"] = issues # Update health dict return health
# ==================== ICheckpointerBackend Implementation ===============
[docs] async def put_checkpoint( self, thread_id: str, checkpoint_id: str, checkpoint_data: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None, ) -> bool: """Store a checkpoint for LangGraph workflows (ICheckpointerBackend interface).""" return await self.store_checkpoint(thread_id, checkpoint_id, checkpoint_data, metadata)
[docs] async def put_writes( self, thread_id: str, checkpoint_id: str, task_id: str, writes_data: List[tuple], ) -> bool: """Store intermediate writes for a checkpoint (ICheckpointerBackend interface).""" writes_key = f"writes:{thread_id}:{checkpoint_id}:{task_id}" writes_payload = { "thread_id": thread_id, "checkpoint_id": checkpoint_id, "task_id": task_id, "writes": writes_data, "created_at": datetime.utcnow().isoformat(), } if self.redis_client: try: await self.redis_client.hset( # type: ignore[misc] f"checkpoint_writes:{thread_id}", f"{checkpoint_id}:{task_id}", json.dumps(writes_payload, cls=DateTimeEncoder), ) await self.redis_client.expire(f"checkpoint_writes:{thread_id}", self.checkpoint_ttl) # Dual-write: append checkpoint writes if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_checkpoint_writes( thread_id=thread_id, checkpoint_id=checkpoint_id, task_id=task_id, writes_data=writes_data, created_at=str(writes_payload["created_at"]), ) ) return True except Exception as e: logger.error(f"Failed to store writes to Redis: {e}") # Fallback to memory self._memory_checkpoints[writes_key] = writes_payload # Dual-write when using memory fallback if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_checkpoint_writes( thread_id=thread_id, checkpoint_id=checkpoint_id, task_id=task_id, writes_data=writes_data, created_at=str(writes_payload["created_at"]), ) ) return True
[docs] async def get_writes(self, thread_id: str, checkpoint_id: str) -> List[tuple]: """Get intermediate writes for a checkpoint (ICheckpointerBackend interface).""" if self.redis_client: try: writes_data = await self.redis_client.hgetall(f"checkpoint_writes:{thread_id}") # type: ignore[misc] writes = [] for key, data in writes_data.items(): if key.startswith(f"{checkpoint_id}:"): payload = json.loads(data) writes.extend(payload.get("writes", [])) return writes except Exception as e: logger.error(f"Failed to get writes from Redis: {e}") # Fallback to memory writes = [] writes_prefix = f"writes:{thread_id}:{checkpoint_id}:" for key, payload in self._memory_checkpoints.items(): if key.startswith(writes_prefix): writes.extend(payload.get("writes", [])) return writes
# ==================== ITaskContextStorage Implementation ================
[docs] async def store_task_context(self, session_id: str, context: Any) -> bool: """Store TaskContext for a session (ITaskContextStorage interface).""" return bool(await self._store_task_context(session_id, context))
# ==================== Agent Communication and Conversation Isolation ====
[docs] async def create_conversation_session( self, session_id: str, participants: List[Dict[str, Any]], session_type: str, metadata: Optional[Dict[str, Any]] = None, ) -> str: """ Create an isolated conversation session between participants. Args: session_id: Base session ID participants: List of participant dictionaries with id, type, role session_type: Type of conversation ('user_to_mc', 'mc_to_agent', 'agent_to_agent', 'user_to_agent') metadata: Additional session metadata Returns: Generated session key for conversation isolation """ from .conversation_models import ( ConversationSession, ConversationParticipant, ) # Create participant objects participant_objects = [ ConversationParticipant( participant_id=p.get("id") or "", participant_type=p.get("type") or "", participant_role=p.get("role"), metadata=p.get("metadata", {}), ) for p in participants ] # Create conversation session conversation_session = ConversationSession( session_id=session_id, participants=participant_objects, session_type=session_type, created_at=datetime.utcnow(), last_activity=datetime.utcnow(), metadata=metadata or {}, ) # Generate unique session key session_key = conversation_session.generate_session_key() # Store conversation session metadata await self._store_conversation_session(session_key, conversation_session) logger.info(f"Created conversation session: {session_key} (type: {session_type})") return session_key
[docs] async def add_agent_communication_message( self, session_key: str, sender_id: str, sender_type: str, sender_role: Optional[str], recipient_id: str, recipient_type: str, recipient_role: Optional[str], content: str, message_type: str = "communication", metadata: Optional[Dict[str, Any]] = None, ) -> bool: """ Add a message to an agent communication session. Args: session_key: Isolated session key sender_id: ID of the sender sender_type: Type of sender ('master_controller', 'agent', 'user') sender_role: Role of sender (for agents) recipient_id: ID of the recipient recipient_type: Type of recipient recipient_role: Role of recipient (for agents) content: Message content message_type: Type of message metadata: Additional message metadata Returns: Success status """ from .conversation_models import AgentCommunicationMessage # Create agent communication message message = AgentCommunicationMessage( message_id=str(uuid.uuid4()), session_key=session_key, sender_id=sender_id, sender_type=sender_type, sender_role=sender_role, recipient_id=recipient_id, recipient_type=recipient_type, recipient_role=recipient_role, content=content, message_type=message_type, timestamp=datetime.utcnow(), metadata=metadata or {}, ) # Convert to conversation message format and store conv_message_dict = message.to_conversation_message_dict() # Store using existing conversation message infrastructure await self.add_conversation_message( session_id=session_key, role=conv_message_dict["role"], content=conv_message_dict["content"], metadata=conv_message_dict["metadata"], ) # Update session activity await self._update_conversation_session_activity(session_key) logger.debug(f"Added agent communication message to session {session_key}") return True
[docs] async def get_agent_conversation_history( self, session_key: str, limit: int = 50, message_types: Optional[List[str]] = None, ) -> List[Dict[str, Any]]: """ Get conversation history for an agent communication session. Args: session_key: Isolated session key limit: Maximum number of messages to retrieve message_types: Filter by message types Returns: List of conversation messages """ # Get conversation history using existing infrastructure messages = await self.get_conversation_history(session_key, limit) # Filter by message types if specified if message_types: filtered_messages = [] for msg in messages: if hasattr(msg, "to_dict"): msg_dict = msg.to_dict() else: msg_dict = msg msg_metadata = msg_dict.get("metadata", {}) msg_type = msg_metadata.get("message_type", "communication") if msg_type in message_types: filtered_messages.append(msg_dict) return filtered_messages # Convert messages to dict format return [msg.to_dict() if hasattr(msg, "to_dict") else msg for msg in messages]
async def _store_conversation_session(self, session_key: str, conversation_session) -> None: """Store conversation session metadata.""" session_data = { "session_id": conversation_session.session_id, "participants": [ { "participant_id": p.participant_id, "participant_type": p.participant_type, "participant_role": p.participant_role, "metadata": p.metadata, } for p in conversation_session.participants ], "session_type": conversation_session.session_type, "created_at": conversation_session.created_at.isoformat(), "last_activity": conversation_session.last_activity.isoformat(), "metadata": conversation_session.metadata, } if self.redis_client: try: await self.redis_client.hset( # type: ignore[misc] "conversation_sessions", session_key, json.dumps(session_data, cls=DateTimeEncoder), ) await self.redis_client.expire("conversation_sessions", self.session_ttl) # Dual-write: append conversation session if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_conversation_session( session_key=session_key, session_data=session_data, created_at=session_data.get("created_at"), ) ) return except Exception as e: logger.error(f"Failed to store conversation session to Redis: {e}") # Fallback to memory (extend memory storage) if not hasattr(self, "_memory_conversation_sessions"): self._memory_conversation_sessions = {} self._memory_conversation_sessions[session_key] = session_data # Dual-write when using memory fallback if self._permanent_backend: await self._fire_permanent( self._permanent_backend.append_conversation_session( session_key=session_key, session_data=session_data, created_at=session_data.get("created_at"), ) ) async def _update_conversation_session_activity(self, session_key: str) -> None: """Update last activity timestamp for a conversation session.""" if self.redis_client: try: session_data = await self.redis_client.hget("conversation_sessions", session_key) # type: ignore[misc] if session_data: session_dict = json.loads(session_data) session_dict["last_activity"] = datetime.utcnow().isoformat() await self.redis_client.hset( # type: ignore[misc] "conversation_sessions", session_key, json.dumps(session_dict, cls=DateTimeEncoder), ) return except Exception as e: logger.error(f"Failed to update conversation session activity in Redis: {e}") # Fallback to memory if hasattr(self, "_memory_conversation_sessions") and session_key in self._memory_conversation_sessions: self._memory_conversation_sessions[session_key]["last_activity"] = datetime.utcnow().isoformat() # ==================== Compression Methods (Phase 6) ====================
[docs] async def compress_conversation( self, session_id: str, strategy: Optional[str] = None, config_override: Optional[CompressionConfig] = None, ) -> Dict[str, Any]: """ Compress conversation history using specified strategy. Args: session_id: Session ID to compress strategy: Compression strategy (overrides config if provided) config_override: Override compression config for this operation Returns: Dictionary with compression results: { "success": bool, "strategy": str, "original_count": int, "compressed_count": int, "compression_ratio": float, "tokens_saved": int (if applicable), "time_taken": float } Example: result = await engine.compress_conversation( session_id="session-123", strategy="summarize" ) print(f"Compressed from {result['original_count']} to {result['compressed_count']} messages") """ import time start_time = time.time() # Use config override or default config = config_override or self.compression_config selected_strategy = strategy or config.strategy logger.info(f"Compressing conversation {session_id} using strategy: {selected_strategy}") try: # Get current conversation messages_dict = await self.get_conversation_history(session_id) # Convert dict list to ConversationMessage list messages = [ConversationMessage.from_dict(msg) for msg in messages_dict] original_count = len(messages) if original_count == 0: return { "success": False, "error": "No messages to compress", "original_count": 0, "compressed_count": 0, } # Select compression strategy if selected_strategy == "truncate": compressed_messages = await self._compress_with_truncation(messages, config) elif selected_strategy == "summarize": compressed_messages = await self._compress_with_summarization(messages, config) elif selected_strategy == "semantic": compressed_messages = await self._compress_with_semantic_dedup(messages, config) elif selected_strategy == "hybrid": compressed_messages = await self._compress_with_hybrid(messages, config) else: raise ValueError(f"Unknown compression strategy: {selected_strategy}") compressed_count = len(compressed_messages) compression_ratio = 1.0 - (compressed_count / original_count) if original_count > 0 else 0.0 # Replace conversation history await self._replace_conversation_history(session_id, compressed_messages) time_taken = time.time() - start_time result = { "success": True, "strategy": selected_strategy, "original_count": original_count, "compressed_count": compressed_count, "compression_ratio": compression_ratio, "time_taken": time_taken, } logger.info(f"Compression complete: {original_count} -> {compressed_count} messages " f"({compression_ratio:.1%} reduction) in {time_taken:.2f}s") return result except Exception as e: logger.error(f"Compression failed for session {session_id}: {e}") return { "success": False, "error": str(e), "strategy": selected_strategy, "time_taken": time.time() - start_time, }
async def _compress_with_truncation(self, messages: List[ConversationMessage], config: CompressionConfig) -> List[ConversationMessage]: """ Compress by truncating old messages (fast, no LLM required). Keeps the most recent N messages based on config.keep_recent. Args: messages: List of conversation messages config: Compression configuration Returns: Truncated list of messages """ if len(messages) <= config.keep_recent: return messages # Keep most recent messages truncated = messages[-config.keep_recent :] logger.debug(f"Truncation: kept {len(truncated)} most recent messages " f"(removed {len(messages) - len(truncated)})") return truncated async def _compress_with_summarization(self, messages: List[ConversationMessage], config: CompressionConfig) -> List[ConversationMessage]: """ Compress using LLM-based summarization. Creates a summary of older messages and keeps recent messages intact. Args: messages: List of conversation messages config: Compression configuration Returns: List with summary message + recent messages Raises: ValueError: If no LLM client configured """ if not self.llm_client: raise ValueError("LLM client required for summarization compression. " "Provide llm_client parameter to ContextEngine.") if len(messages) <= config.keep_recent: return messages # Split into messages to summarize and messages to keep messages_to_summarize = messages[: -config.keep_recent] messages_to_keep = messages[-config.keep_recent :] # Build summary prompt summary_prompt = self._build_summary_prompt(messages_to_summarize, config) # Generate summary using LLM from aiecs.llm.clients.base_client import LLMMessage llm_messages = [LLMMessage(role="user", content=summary_prompt)] response = await self.llm_client.generate_text(messages=llm_messages, max_tokens=config.summary_max_tokens) summary_text = response.content # Create summary message summary_message = ConversationMessage( role="system", content=f"[Summary of {len(messages_to_summarize)} previous messages]\n\n{summary_text}", timestamp=datetime.utcnow(), metadata={"type": "summary", "summarized_count": len(messages_to_summarize)}, ) # Combine summary + recent messages if config.include_summary_in_history: compressed = [summary_message] + messages_to_keep else: compressed = messages_to_keep logger.debug(f"Summarization: {len(messages_to_summarize)} messages -> 1 summary, " f"kept {len(messages_to_keep)} recent messages") return compressed def _build_summary_prompt(self, messages: List[ConversationMessage], config: CompressionConfig) -> str: """ Build prompt for summarization. Args: messages: Messages to summarize config: Compression configuration Returns: Prompt string for LLM """ # Use custom template if provided if config.summary_prompt_template: # Format template with messages messages_text = "\n\n".join([f"{msg.role}: {msg.content}" for msg in messages]) return config.summary_prompt_template.format(messages=messages_text) # Default template messages_text = "\n\n".join([f"{msg.role}: {msg.content}" for msg in messages]) prompt = f"""Please provide a concise summary of the following conversation. Focus on key points, decisions, and important information. Keep the summary under {config.summary_max_tokens} tokens. Conversation: {messages_text} Summary:""" return prompt async def _compress_with_semantic_dedup(self, messages: List[ConversationMessage], config: CompressionConfig) -> List[ConversationMessage]: """ Compress using semantic deduplication (embedding-based). Removes messages that are semantically similar to keep diverse content. Args: messages: List of conversation messages config: Compression configuration Returns: List of semantically diverse messages Raises: ValueError: If no LLM client configured """ if not self.llm_client: raise ValueError("LLM client required for semantic deduplication. " "Provide llm_client parameter to ContextEngine.") if len(messages) <= config.keep_recent: return messages # Get embeddings for all messages texts = [msg.content for msg in messages] try: embeddings = await self.llm_client.get_embeddings(texts=texts, model=config.embedding_model) except NotImplementedError: logger.warning("LLM client does not support embeddings. Falling back to truncation.") return await self._compress_with_truncation(messages, config) # Find diverse messages using embeddings diverse_indices = self._find_diverse_messages(embeddings, config.similarity_threshold, config.keep_recent) # Keep messages at diverse indices compressed = [messages[i] for i in sorted(diverse_indices)] logger.debug(f"Semantic dedup: kept {len(compressed)} diverse messages " f"(removed {len(messages) - len(compressed)} similar messages)") return compressed def _find_diverse_messages(self, embeddings: List[List[float]], similarity_threshold: float, target_count: int) -> List[int]: """ Find diverse messages using embeddings. Uses greedy selection to find messages that are semantically diverse. Args: embeddings: List of embedding vectors similarity_threshold: Similarity threshold for deduplication target_count: Target number of messages to keep Returns: List of indices of diverse messages """ import numpy as np if len(embeddings) <= target_count: return list(range(len(embeddings))) # Convert to numpy array emb_array = np.array(embeddings) # Normalize embeddings for cosine similarity norms = np.linalg.norm(emb_array, axis=1, keepdims=True) emb_normalized = emb_array / (norms + 1e-8) # Greedy selection: always keep most recent messages selected_indices = list(range(len(embeddings) - target_count, len(embeddings))) # For older messages, select diverse ones remaining_indices = list(range(len(embeddings) - target_count)) while remaining_indices and len(selected_indices) < target_count: # Find message most different from selected ones max_min_distance = -1 best_idx = None for idx in remaining_indices: # Calculate similarity to all selected messages similarities = np.dot(emb_normalized[idx], emb_normalized[selected_indices].T) min_similarity = np.min(similarities) if len(similarities) > 0 else 0 # We want maximum minimum distance (most diverse) if min_similarity > max_min_distance: max_min_distance = min_similarity best_idx = idx if best_idx is not None and max_min_distance < similarity_threshold: selected_indices.append(best_idx) remaining_indices.remove(best_idx) else: break return selected_indices async def _replace_conversation_history(self, session_id: str, messages: List[ConversationMessage]) -> None: """ Replace conversation history with compressed messages. Args: session_id: Session ID messages: New list of messages """ if self.redis_client: try: # Clear existing messages await self.redis_client.delete(f"conversation:{session_id}") # Store new messages for msg in messages: await self.redis_client.rpush( # type: ignore[misc] f"conversation:{session_id}", json.dumps(msg.to_dict(), cls=DateTimeEncoder), ) # Set TTL await self.redis_client.expire(f"conversation:{session_id}", self.session_ttl) logger.debug(f"Replaced conversation history for {session_id} with {len(messages)} messages") return except Exception as e: logger.error(f"Failed to replace conversation history in Redis: {e}") # Fallback to memory self._memory_conversations[session_id] = messages logger.debug(f"Replaced conversation history (memory) for {session_id} with {len(messages)} messages") async def _compress_with_hybrid(self, messages: List[ConversationMessage], config: CompressionConfig) -> List[ConversationMessage]: """ Compress using hybrid strategy (combination of multiple strategies). Applies multiple compression strategies in sequence based on config.hybrid_strategies. Args: messages: List of conversation messages config: Compression configuration Returns: Compressed list of messages Example: # Default hybrid: truncate then summarize config = CompressionConfig( strategy="hybrid", hybrid_strategies=["truncate", "summarize"] ) """ compressed = messages # Type narrowing: ensure hybrid_strategies is a list if config.hybrid_strategies is None: config.hybrid_strategies = ["truncate", "summarize"] for strategy in config.hybrid_strategies: if strategy == "truncate": compressed = await self._compress_with_truncation(compressed, config) elif strategy == "summarize": compressed = await self._compress_with_summarization(compressed, config) elif strategy == "semantic": compressed = await self._compress_with_semantic_dedup(compressed, config) else: logger.warning(f"Unknown hybrid strategy: {strategy}, skipping") logger.debug(f"Hybrid compression: {len(messages)} -> {len(compressed)} messages " f"using strategies: {', '.join(config.hybrid_strategies)}") return compressed
[docs] async def auto_compress_on_limit(self, session_id: str) -> Optional[Dict[str, Any]]: """ Automatically compress conversation if it exceeds threshold. Checks if conversation exceeds auto_compress_threshold and compresses to auto_compress_target if needed. Args: session_id: Session ID to check Returns: Compression result dict if compression was triggered, None otherwise Example: # Configure auto-compression config = CompressionConfig( auto_compress_enabled=True, auto_compress_threshold=100, auto_compress_target=50 ) engine = ContextEngine(compression_config=config) # Check and auto-compress if needed result = await engine.auto_compress_on_limit(session_id) if result: print(f"Auto-compressed: {result['original_count']} -> {result['compressed_count']}") """ if not self.compression_config.auto_compress_enabled: return None # Get current message count messages = await self.get_conversation_history(session_id) message_count = len(messages) # Check if threshold exceeded if message_count <= self.compression_config.auto_compress_threshold: return None logger.info(f"Auto-compression triggered for {session_id}: " f"{message_count} messages exceeds threshold of " f"{self.compression_config.auto_compress_threshold}") # Compress conversation result = await self.compress_conversation(session_id) if result.get("success"): logger.info(f"Auto-compression complete for {session_id}: " f"{result['original_count']} -> {result['compressed_count']} messages") return result
[docs] async def get_compressed_context( self, session_id: str, format: str = "messages", compress_first: bool = False, ) -> Any: """ Get conversation context in compressed format. Args: session_id: Session ID format: Output format - "messages", "string", or "dict" compress_first: Whether to compress before returning Returns: Conversation in requested format: - "messages": List[ConversationMessage] - "string": Formatted string - "dict": List[Dict[str, Any]] Example: # Get as formatted string context = await engine.get_compressed_context( session_id="session-123", format="string" ) print(context) # Get as messages, compress first messages = await engine.get_compressed_context( session_id="session-456", format="messages", compress_first=True ) """ # Compress first if requested if compress_first: await self.compress_conversation(session_id) # Get conversation history messages = await self.get_conversation_history(session_id) # Return in requested format if format == "messages": return messages elif format == "string": # Format as string lines = [] for msg in messages: # messages is List[Dict[str, Any]] from get_conversation_history timestamp = msg.get("timestamp", "").strftime("%Y-%m-%d %H:%M:%S") if isinstance(msg.get("timestamp"), datetime) else str(msg.get("timestamp", "")) role = msg.get("role", "") content = msg.get("content", "") lines.append(f"[{timestamp}] {role}: {content}") return "\n\n".join(lines) elif format == "dict": # Return as list of dicts (already dicts from get_conversation_history) return [self._sanitize_for_json(msg) for msg in messages] else: raise ValueError(f"Invalid format '{format}'. Must be 'messages', 'string', or 'dict'")
def _sanitize_for_json(self, obj: Any) -> Any: """ Sanitize object for JSON serialization. Handles common non-serializable types like datetime, dataclasses, etc. Args: obj: Object to sanitize Returns: JSON-serializable version of object Note: This is similar to _sanitize_dataclasses but more general purpose. """ # Use existing sanitization logic return self._sanitize_dataclasses(obj)