Source code for aiecs.domain.task.task_context

# /*---------------------------------------------------------------------------------------------
#  *  Copyright (c) IRETBL Corporation. All rights reserved.
#  *  Licensed under the Apache-2.0. See License.txt in the project root for license information.
#  *--------------------------------------------------------------------------------------------*/
import inspect
import json
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, Optional, AsyncGenerator, List

logger = logging.getLogger(__name__)


@dataclass
class ContextUpdate:
    """One entry in a task's context history.

    Attributes:
        timestamp: Epoch seconds at which the update was recorded.
        update_type: Category label, e.g. ``"message"``, ``"metadata"`` or
            ``"resource"``.
        data: The update's payload, e.g. message text or a metadata dict.
        metadata: Supplementary details such as file paths or model info.
    """

    timestamp: float
    update_type: str  # category label: "message", "metadata", "resource", ...
    data: Any  # payload: message text, metadata dict, ...
    metadata: Dict[str, Any]  # extra details: file paths, model info, ...
class TaskContext:
    """
    Enhanced context manager for task execution with:
    - Context history tracking and checkpointing
    - Resource acquisition and release
    - Performance tracking
    - File and model tracking
    - Persistent storage
    - Metadata toggles
    - Enhanced error handling

    Context history is persisted as JSON in
    ``<task_dir>/context_history_<chat_id>.json``. It is loaded on
    construction and written back on context exit (sync and async) and
    after truncation.
    """

    def __init__(self, data: dict, task_dir: str = "./tasks"):
        """
        Build a task context from a request payload.

        Args:
            data: Request payload. Keys read: ``user_id`` (default
                ``"anonymous"``), ``chat_id`` (default ``"none"``),
                ``metadata``, ``aiPreference`` and ``metadata_toggles``.
            task_dir: Directory holding the context history file; it is
                created if missing.
        """
        self.user_id = data.get("user_id", "anonymous")
        self.chat_id = data.get("chat_id", "none")
        # Copy so that adding aiPreference below does not mutate the caller's
        # dict; a missing or None value becomes an empty dict.
        self.metadata: Dict[str, Any] = dict(data.get("metadata") or {})
        # Ensure metadata includes aiPreference
        if "aiPreference" in data:
            self.metadata["aiPreference"] = data["aiPreference"]
        self.task_dir = Path(task_dir)
        self.start_time: Optional[float] = None
        self.resources: Dict[str, Any] = {}
        self.context_history: List[ContextUpdate] = []
        # Tracks file operations, keyed by file path
        self.file_tracker: Dict[str, Dict[str, Any]] = {}
        self.model_tracker: List[Dict[str, Any]] = []  # Tracks model usage
        # Maps metadata key -> visibility; None is treated as "no toggles".
        self.metadata_toggles: Dict[str, bool] = data.get("metadata_toggles") or {}
        self._initialize_persistence()

    @property
    def _history_file(self) -> Path:
        """Path of the JSON file holding this chat's context history."""
        # NOTE(review): chat_id comes from the request payload and is put into
        # the filename unsanitised; a value containing path separators (e.g.
        # "../x") would resolve outside task_dir. Confirm upstream validation.
        return self.task_dir / f"context_history_{self.chat_id}.json"

    def _initialize_persistence(self):
        """Create ``task_dir`` and load any previously saved context history.

        Best effort: any failure is logged and ``context_history`` stays empty.
        """
        try:
            self.task_dir.mkdir(parents=True, exist_ok=True)
            history_file = self._history_file
            if history_file.exists():
                with open(history_file, "r", encoding="utf-8") as f:
                    raw_history = json.load(f)
                self.context_history = [
                    ContextUpdate(
                        timestamp=entry["timestamp"],
                        update_type=entry["update_type"],
                        data=entry["data"],
                        metadata=entry.get("metadata") or {},
                    )
                    for entry in raw_history
                ]
                logger.debug(f"Loaded context history from {history_file}")
        except Exception as e:
            logger.error(f"Failed to initialize context history: {e}")

    def _serialize_history(self) -> List[Dict[str, Any]]:
        """Return the context history as a list of plain dicts."""
        return [
            {
                "timestamp": update.timestamp,
                "update_type": update.update_type,
                "data": update.data,
                "metadata": update.metadata,
            }
            for update in self.context_history
        ]

    def _write_context_history(self) -> None:
        """Write context history to disk synchronously (best effort).

        The JSON text is fully built before the file is opened. A
        non-serialisable entry therefore fails before the existing file is
        truncated, and the previously saved history survives.
        """
        try:
            history_file = self._history_file
            payload = json.dumps(self._serialize_history(), indent=2)
            history_file.write_text(payload, encoding="utf-8")
            logger.debug(f"Saved context history to {history_file}")
        except Exception as e:
            logger.error(f"Failed to save context history: {e}")

    async def _save_context_history(self):
        """Save context history to disk (async entry point; the write itself is blocking)."""
        self._write_context_history()

    def add_context_update(self, update_type: str, data: Any, metadata: Optional[Dict[str, Any]] = None):
        """Add a context update (e.g., message, metadata change), stamped with the current time."""
        update = ContextUpdate(
            timestamp=time.time(),
            update_type=update_type,
            data=data,
            metadata=metadata or {},
        )
        self.context_history.append(update)
        logger.debug(f"Added context update: {update_type}")

    def add_resource(self, name: str, resource: Any) -> None:
        """Add a resource that needs cleanup; its ``close()`` is called on context exit."""
        self.resources[name] = resource
        self.add_context_update("resource", {"name": name}, {"type": type(resource).__name__})
        logger.debug(f"Added resource: {name}")

    def track_file_operation(self, file_path: str, operation: str, source: str = "task"):
        """Track a file operation (e.g., read, edit); the latest operation per path wins."""
        self.file_tracker[file_path] = {
            "operation": operation,
            "source": source,
            "timestamp": time.time(),
            "state": "active",
        }
        self.add_context_update(
            "file_operation",
            {"path": file_path, "operation": operation},
            {"source": source},
        )
        logger.debug(f"Tracked file operation: {operation} on {file_path}")

    def track_model_usage(self, model_id: str, provider_id: str, mode: str):
        """Track AI model usage.

        Consecutive calls for the same ``(model_id, provider_id, mode)`` are
        collapsed. Only the first is appended to ``model_tracker`` and
        recorded as a ``"model_usage"`` context update.
        """
        model_entry = {
            "model_id": model_id,
            "provider_id": provider_id,
            "mode": mode,
            "timestamp": time.time(),
        }
        # Compare identity fields only: the timestamp differs on every call,
        # so comparing whole entries would never detect a duplicate.
        identity_keys = ("model_id", "provider_id", "mode")
        last = self.model_tracker[-1] if self.model_tracker else None
        if last is None or any(last.get(k) != model_entry[k] for k in identity_keys):
            self.model_tracker.append(model_entry)
            self.add_context_update("model_usage", model_entry)
            logger.debug(f"Tracked model usage: {model_id} ({provider_id}, {mode})")

    def optimize_context(self, max_size: int = 1000) -> bool:
        """Optimize context by removing duplicates and old entries.

        Walks the history newest-first and keeps only the most recent
        occurrence of each ``(update_type, data)`` pair. An entry is kept only
        if adding ``len(str(update.data))`` keeps the running total within
        ``max_size``. An entry that would exceed the budget is skipped, but
        older, smaller entries may still be kept. Kept entries retain their
        original relative order.

        Args:
            max_size: Character budget for the summed ``str(data)`` lengths.

        Returns:
            True if any entry was removed, False if the history is unchanged.
        """
        original_len = len(self.context_history)
        seen_keys = set()
        optimized_history = []
        total_size = 0
        for update in reversed(self.context_history):
            # default=str: the key only has to be stable, so stringify values
            # json cannot encode instead of raising TypeError.
            key = f"{update.update_type}:{json.dumps(update.data, sort_keys=True, default=str)}"
            if key in seen_keys:
                continue
            seen_keys.add(key)
            data_size = len(str(update.data))
            if total_size + data_size <= max_size:
                optimized_history.append(update)
                total_size += data_size
        optimized_history.reverse()
        self.context_history = optimized_history
        # Compare against the ORIGINAL length, captured before reassignment.
        removed = original_len - len(optimized_history)
        if removed:
            duplicates = original_len - len(seen_keys)
            logger.debug(f"Optimized context: removed {removed} entries ({duplicates} duplicates)")
            return True
        return False

    async def truncate_context_history(self, timestamp: float):
        """Drop history entries newer than ``timestamp`` and persist if anything changed."""
        original_len = len(self.context_history)
        self.context_history = [update for update in self.context_history if update.timestamp <= timestamp]
        if len(self.context_history) < original_len:
            await self._save_context_history()
            logger.debug(f"Truncated context history at timestamp {timestamp}")

    def get_active_metadata(self) -> Dict[str, Any]:
        """Return metadata filtered by toggles; only an explicit ``False`` toggle hides a key."""
        return {
            key: value
            for key, value in self.metadata.items()
            if key not in self.metadata_toggles or self.metadata_toggles[key] is not False
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert context to dictionary (metadata filtered through the toggles)."""
        return {
            "user_id": self.user_id,
            "chat_id": self.chat_id,
            "metadata": self.get_active_metadata(),
            "context_history": self._serialize_history(),
            "file_tracker": self.file_tracker,
            "model_tracker": self.model_tracker,
        }

    def __enter__(self):
        """Synchronous context entry; records the start time."""
        self.start_time = time.time()
        logger.debug(f"Starting task context for user {self.user_id}, chat {self.chat_id}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Synchronous context exit with cleanup.

        Closes every registered resource that has ``close()``, logs any
        exception, and persists the history. Returns False so exceptions
        propagate.
        """
        duration = time.time() - (self.start_time or 0.0)
        logger.debug(f"Completed task context in {duration:.2f}s for user {self.user_id}")
        for resource_name, resource in self.resources.items():
            try:
                if hasattr(resource, "close"):
                    resource.close()
                    logger.debug(f"Cleaned up resource: {resource_name}")
            except Exception as e:
                logger.error(f"Error cleaning up resource {resource_name}: {e}")
        if exc_type:
            logger.error(f"Task context exited with error: {exc_val}")
        # Persist like __aexit__ does, so sync usage does not lose history.
        self._write_context_history()
        return False

    async def __aenter__(self):
        """Asynchronous context entry; records the start time."""
        self.start_time = time.time()
        logger.debug(f"Starting async task context for user {self.user_id}, chat {self.chat_id}")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Asynchronous context exit with cleanup.

        Calls ``close()`` on each registered resource and awaits the result
        when it is awaitable, so both sync and ``async def close()`` work.
        Then logs any exception and persists the history. Returns False so
        exceptions propagate.
        """
        duration = time.time() - (self.start_time or 0.0)
        logger.debug(f"Completed async task context in {duration:.2f}s for user {self.user_id}")
        for resource_name, resource in self.resources.items():
            try:
                close = getattr(resource, "close", None)
                if callable(close):
                    # Inspect the *result*: a bound coroutine method has no
                    # __await__ itself, only the coroutine it returns does.
                    result = close()
                    if inspect.isawaitable(result):
                        await result
                    logger.debug(f"Cleaned up async resource: {resource_name}")
            except Exception as e:
                logger.error(f"Error cleaning up async resource {resource_name}: {e}")
        if exc_type:
            logger.error(f"Async task context exited with error: {exc_val}")
        await self._save_context_history()
        return False
def build_context(data: dict) -> dict:
    """Return the dictionary form of a ``TaskContext`` built from *data*.

    Kept for backward compatibility with callers that want a plain dict.
    The context is created with the default ``./tasks`` directory, which is
    created if missing and from which any saved history is loaded.
    """
    return TaskContext(data).to_dict()
@asynccontextmanager
async def task_context(data: dict, task_dir: str = "./tasks") -> AsyncGenerator[TaskContext, None]:
    """
    Async context manager for task execution.

    Delegates entry and exit to ``TaskContext.__aenter__``/``__aexit__``.
    If the caller's block raises, the exception details are passed to
    ``__aexit__``, which logs them and returns False, so the exception then
    propagates to the caller.

    Usage:
        async with task_context(request_data, task_dir="/path/to/tasks") as context:
            context.add_context_update("message", "User input", {"source": "user"})
            context.track_file_operation("example.py", "read", "tool")
            result = await service_instance.run(data, context)
    """
    context = TaskContext(data, task_dir)
    # ``async with`` forwards any exception thrown into this generator at the
    # yield to __aexit__(exc_type, exc, tb); a plain try/finally calling
    # __aexit__(None, None, None) would hide the failure from it.
    async with context:
        yield context