# /*---------------------------------------------------------------------------------------------
# * Copyright (c) IRETBL Corporation. All rights reserved.
# * Licensed under the Apache-2.0. See License.txt in the project root for license information.
# *--------------------------------------------------------------------------------------------*/
import logging
from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING
from enum import Enum
from .clients.base_client import BaseLLMClient, LLMMessage, LLMResponse
from .clients.openai_client import OpenAIClient
from .clients.vertex_client import VertexAIClient
from .clients.googleai_client import GoogleAIClient
from .clients.xai_client import XAIClient
from .clients.openrouter_client import OpenRouterClient
from .clients.anthropic_client import AnthropicVertexClient
from .clients.vertex_maas_client import VertexMaaSClient
from .callbacks.custom_callbacks import CustomAsyncCallbackHandler
if TYPE_CHECKING:
from .protocols import LLMClientProtocol
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
class AIProvider(str, Enum):
    """Names of the standard LLM providers handled by ``LLMClientFactory``.

    Every member is also a ``str``, so it compares equal to its value and can
    be looked up from it, e.g. ``AIProvider("OpenAI") is AIProvider.OPENAI``.
    """

    OPENAI = "OpenAI"
    VERTEX = "Vertex"
    GOOGLEAI = "GoogleAI"
    XAI = "xAI"
    OPENROUTER = "OpenRouter"
    ANTHROPIC_VERTEX = "AnthropicVertex"
    VERTEX_MAAS = "VertexMaaS"
class LLMClientFactory:
    """Factory for creating and managing LLM provider clients.

    Clients for the standard ``AIProvider`` members are created on first
    request and cached. Custom clients are stored under the name they were
    registered with. Both registries are class-level dictionaries, so they
    are shared by everything that uses this factory.
    """

    # Cached clients for the standard providers, keyed by AIProvider member.
    _clients: Dict[AIProvider, BaseLLMClient] = {}
    # Clients added through register_custom_provider(), keyed by name.
    _custom_clients: Dict[str, "LLMClientProtocol"] = {}

    @classmethod
    def register_custom_provider(cls, name: str, client: "LLMClientProtocol") -> None:
        """
        Register a custom LLM client provider.

        This allows registration of custom LLM clients that implement the
        LLMClientProtocol without inheriting from BaseLLMClient. Custom
        providers can be retrieved by name using get_client(). Registering a
        name that is already registered replaces the previous client.

        Args:
            name: Custom provider name (e.g., "my-llm", "llama-local", "custom-gpt")
            client: Client implementing LLMClientProtocol

        Raises:
            ValueError: If client doesn't implement LLMClientProtocol
            ValueError: If name conflicts with standard AIProvider enum values

        Example:
            ```python
            # Register custom LLM client
            custom_client = MyCustomLLMClient()
            LLMClientFactory.register_custom_provider("my-llm", custom_client)

            # Use custom client
            client = LLMClientFactory.get_client("my-llm")
            response = await client.generate_text(messages)
            ```
        """
        # Import here to avoid circular dependency
        from .protocols import LLMClientProtocol

        # Validate protocol compliance
        if not isinstance(client, LLMClientProtocol):
            raise ValueError(
                "Client must implement LLMClientProtocol. "
                "Required methods: generate_text, stream_text, close, get_embeddings. "
                "Required attribute: provider_name"
            )

        # Prevent conflicts with standard provider names. The enum lookup
        # raising ValueError means the name is free; succeeding means a clash.
        try:
            AIProvider(name)
        except ValueError:
            pass
        else:
            raise ValueError(
                f"Custom provider name '{name}' conflicts with standard AIProvider enum. "
                f"Please use a different name."
            )

        cls._custom_clients[name] = client
        logger.info(f"Registered custom LLM provider: {name}")

    @classmethod
    def get_client(cls, provider: Union[str, AIProvider]) -> Union[BaseLLMClient, "LLMClientProtocol"]:
        """
        Get or create a client for the specified provider.

        Supports both standard AIProvider enum values and custom provider
        names registered via register_custom_provider(). Standard clients are
        created on first use and cached for later calls.

        Args:
            provider: AIProvider enum or custom provider name string

        Returns:
            LLM client (BaseLLMClient for standard providers, LLMClientProtocol for custom)

        Raises:
            ValueError: If provider is unknown (not standard and not registered)
        """
        # Check custom providers first
        if isinstance(provider, str) and provider in cls._custom_clients:
            return cls._custom_clients[provider]

        # Handle standard providers
        if isinstance(provider, str):
            try:
                provider = AIProvider(provider)
            except ValueError as err:
                raise ValueError(
                    f"Unknown provider: {provider}. "
                    f"Standard providers: {[p.value for p in AIProvider]}. "
                    f"Custom providers: {list(cls._custom_clients.keys())}. "
                    f"Register custom providers with LLMClientFactory.register_custom_provider()"
                ) from err

        if provider not in cls._clients:
            cls._clients[provider] = cls._create_client(provider)
        return cls._clients[provider]

    @classmethod
    def _create_client(cls, provider: AIProvider) -> BaseLLMClient:
        """Create a new client instance for the provider.

        Raises:
            ValueError: If no client class is mapped to ``provider``.
        """
        # Built per call so the client classes are resolved at call time,
        # exactly as a chain of comparisons would resolve them.
        client_classes = {
            AIProvider.OPENAI: OpenAIClient,
            AIProvider.VERTEX: VertexAIClient,
            AIProvider.GOOGLEAI: GoogleAIClient,
            AIProvider.XAI: XAIClient,
            AIProvider.OPENROUTER: OpenRouterClient,
            AIProvider.ANTHROPIC_VERTEX: AnthropicVertexClient,
            AIProvider.VERTEX_MAAS: VertexMaaSClient,
        }
        client_class = client_classes.get(provider)
        if client_class is None:
            raise ValueError(f"Unsupported provider: {provider}")
        return client_class()

    @classmethod
    async def close_all(cls):
        """Close all active clients (both standard and custom).

        Errors raised by an individual ``close()`` are logged and do not stop
        the remaining clients from being closed. Both registries are emptied.
        """
        # Snapshot and empty each registry *before* awaiting. Awaiting while
        # iterating the live dict would raise RuntimeError if another task
        # added a client in the meantime, leaving the rest unclosed.
        standard_clients = list(cls._clients.values())
        cls._clients.clear()
        for client in standard_clients:
            try:
                await client.close()
            except Exception as e:
                logger.error(f"Error closing client {client.provider_name}: {e}")

        custom_clients = list(cls._custom_clients.items())
        cls._custom_clients.clear()
        for name, custom_client in custom_clients:
            try:
                await custom_client.close()
            except Exception as e:
                logger.error(f"Error closing custom client {name}: {e}")

    @classmethod
    async def close_client(cls, provider: Union[str, AIProvider]):
        """Close a specific client (standard or custom).

        The client is removed from its registry even when ``close()`` fails,
        so a client in an unknown state is never handed out again. Failures
        are logged, not raised. Unknown provider names only produce a warning.

        Args:
            provider: AIProvider enum or custom provider name string
        """
        # Check if it's a custom provider
        if isinstance(provider, str) and provider in cls._custom_clients:
            custom_client = cls._custom_clients.pop(provider)
            try:
                await custom_client.close()
                logger.info(f"Closed custom client: {provider}")
            except Exception as e:
                logger.error(f"Error closing custom client {provider}: {e}")
            return

        # Handle standard providers
        if isinstance(provider, str):
            try:
                provider = AIProvider(provider)
            except ValueError:
                logger.warning(f"Unknown provider to close: {provider}")
                return

        client = cls._clients.pop(provider, None)
        if client is None:
            # Nothing was ever created for this provider.
            return
        try:
            await client.close()
        except Exception as e:
            logger.error(f"Error closing client {provider}: {e}")

    @classmethod
    def reload_config(cls):
        """
        Reload LLM models configuration.

        Delegates to ``aiecs.llm.config.reload_llm_config`` (which re-reads the
        YAML configuration file), allowing for hot-reloading of model settings
        without restarting the application.

        Returns:
            The configuration object returned by ``reload_llm_config()``.

        Raises:
            Exception: Any error raised while reloading is logged and re-raised.
        """
        try:
            from aiecs.llm.config import reload_llm_config

            config = reload_llm_config()
            logger.info(f"Reloaded LLM configuration: {len(config.providers)} providers")
            return config
        except Exception as e:
            logger.error(f"Failed to reload LLM configuration: {e}")
            raise
class LLMClientManager:
    """High-level manager for LLM operations with context-aware provider selection.

    Provider and model are resolved in this order: preferences found in the
    ``context`` argument, then the explicit ``provider`` / ``model``
    arguments, then ``AIProvider.OPENAI`` as the default provider.
    """

    def __init__(self):
        """Create a manager that obtains its clients from ``LLMClientFactory``."""
        self.factory = LLMClientFactory()

    def _extract_ai_preference(self, context: Optional[Dict[str, Any]]) -> tuple[Optional[str], Optional[str]]:
        """Extract AI provider and model from context.

        ``context["metadata"]["aiPreference"]`` is used when it is a dict with
        a non-None ``provider``. Otherwise the ``provider`` / ``model`` keys
        directly under ``metadata`` are returned.

        Args:
            context: Mapping that may hold a ``"metadata"`` mapping, or None.

        Returns:
            ``(provider, model)``; either element is None when not specified.
        """
        if not context:
            return None, None

        # ``or {}`` also covers an explicit ``"metadata": None`` entry, which
        # would otherwise fail on the ``.get`` calls below.
        metadata = context.get("metadata") or {}

        # First, check for aiPreference in metadata
        ai_preference = metadata.get("aiPreference", {})
        if isinstance(ai_preference, dict):
            provider = ai_preference.get("provider")
            if provider is not None:
                return provider, ai_preference.get("model")

        # Fallback to direct provider/model in metadata
        return metadata.get("provider"), metadata.get("model")

    def _resolve_target(self, provider, model, context):
        """Return the ``(provider, model)`` to use, preferring context values."""
        context_provider, context_model = self._extract_ai_preference(context)
        return context_provider or provider or AIProvider.OPENAI, context_model or model

    @staticmethod
    def _normalize_messages(messages):
        """Turn None into an empty list and a string prompt into one user message."""
        if messages is None:
            return []
        if isinstance(messages, str):
            return [LLMMessage(role="user", content=messages)]
        return messages

    @staticmethod
    def _messages_to_dicts(messages):
        """Convert LLMMessage objects to role/content dictionaries for callbacks."""
        return [{"role": msg.role, "content": msg.content} for msg in messages] if messages else []

    @staticmethod
    def _response_to_dict(response):
        """Convert an LLMResponse object to the dictionary passed to callbacks."""
        return {
            "content": response.content,
            "provider": response.provider,
            "model": response.model,
            "tokens_used": response.tokens_used,
            "prompt_tokens": response.prompt_tokens,
            "completion_tokens": response.completion_tokens,
            "cost_estimate": response.cost_estimate,
            "response_time": response.response_time,
        }

    @staticmethod
    def _provider_label(provider) -> str:
        """Return the provider as a plain string.

        ``str()`` on an enum member gives ``"AIProvider.OPENAI"``, not the
        provider name, so enum members are mapped to their value instead.
        """
        return provider.value if isinstance(provider, Enum) else str(provider)

    async def _dispatch_callbacks(self, callbacks, hook_name, payload, provider, model, extra):
        """Await ``hook_name`` on every callback, logging failures without raising."""
        for callback in callbacks or ():
            try:
                await getattr(callback, hook_name)(payload, provider=provider, model=model, **extra)
            except Exception as callback_error:
                logger.error(f"Error in callback {hook_name}: {callback_error}")

    async def generate_text(
        self,
        messages: Union[str, list[LLMMessage]],
        provider: Optional[Union[str, AIProvider]] = None,
        model: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Generate text using context-aware provider selection

        Args:
            messages: Either a string prompt or list of LLMMessage objects
            provider: AI provider to use (can be overridden by context)
            model: Specific model to use (can be overridden by context)
            context: TaskContext or dict containing aiPreference
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            callbacks: List of callback handlers to execute during LLM calls
            **kwargs: Additional provider-specific parameters

        Returns:
            LLMResponse object with generated text and metadata

        Raises:
            Exception: Any error from client lookup or generation is passed
                to the ``on_llm_error`` callbacks and then re-raised.
        """
        final_provider, final_model = self._resolve_target(provider, model, context)
        messages = self._normalize_messages(messages)

        # Execute on_llm_start callbacks
        if callbacks:
            await self._dispatch_callbacks(
                callbacks, "on_llm_start", self._messages_to_dicts(messages), final_provider, final_model, kwargs
            )

        try:
            # Get the appropriate client
            client = self.factory.get_client(final_provider)

            # Generate text
            response = await client.generate_text(
                messages=messages,
                model=final_model,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )

            # Execute on_llm_end callbacks
            if callbacks:
                await self._dispatch_callbacks(
                    callbacks, "on_llm_end", self._response_to_dict(response), final_provider, final_model, kwargs
                )

            logger.info(f"Generated text using {final_provider}/{response.model}")
            return response
        except Exception as e:
            # Execute on_llm_error callbacks
            if callbacks:
                await self._dispatch_callbacks(callbacks, "on_llm_error", e, final_provider, final_model, kwargs)
            # Re-raise the original exception
            raise

    async def stream_text(
        self,
        messages: Union[str, list[LLMMessage]],
        provider: Optional[Union[str, AIProvider]] = None,
        model: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
        **kwargs,
    ):
        """
        Stream text generation using context-aware provider selection

        Args:
            messages: Either a string prompt or list of LLMMessage objects
            provider: AI provider to use (can be overridden by context)
            model: Specific model to use (can be overridden by context)
            context: TaskContext or dict containing aiPreference
            temperature: Sampling temperature (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            callbacks: List of callback handlers to execute during LLM calls
            **kwargs: Additional provider-specific parameters

        Yields:
            Chunks exactly as produced by the client's ``stream_text``: plain
            strings, or chunk objects carrying a ``content`` attribute.

        Raises:
            Exception: Any error from client lookup or streaming is passed
                to the ``on_llm_error`` callbacks and then re-raised.
        """
        final_provider, final_model = self._resolve_target(provider, model, context)
        messages = self._normalize_messages(messages)

        # Execute on_llm_start callbacks
        if callbacks:
            await self._dispatch_callbacks(
                callbacks, "on_llm_start", self._messages_to_dicts(messages), final_provider, final_model, kwargs
            )

        try:
            # Get the appropriate client
            client = self.factory.get_client(final_provider)

            # Collect streamed content for token counting
            collected_parts: list[str] = []

            # Stream text
            stream_gen = client.stream_text(
                messages=messages,
                model=final_model,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs,
            )
            async for chunk in stream_gen:
                # Handle StreamChunk objects (when return_chunks=True or function calling)
                if hasattr(chunk, "content") and chunk.content:
                    collected_parts.append(chunk.content)
                elif isinstance(chunk, str):
                    collected_parts.append(chunk)
                yield chunk

            # Streaming doesn't return an LLMResponse, so build one for the
            # callbacks with a rough token estimate of four characters each.
            collected_content = "".join(collected_parts)
            estimated_tokens = len(collected_content) // 4
            stream_response = LLMResponse(
                content=collected_content,
                provider=self._provider_label(final_provider),
                model=final_model or "unknown",
                tokens_used=estimated_tokens,
            )

            # Execute on_llm_end callbacks
            if callbacks:
                await self._dispatch_callbacks(
                    callbacks,
                    "on_llm_end",
                    self._response_to_dict(stream_response),
                    final_provider,
                    final_model,
                    kwargs,
                )
        except Exception as e:
            # Execute on_llm_error callbacks
            if callbacks:
                await self._dispatch_callbacks(callbacks, "on_llm_error", e, final_provider, final_model, kwargs)
            # Re-raise the original exception
            raise

    async def close(self):
        """Close all clients"""
        await self.factory.close_all()
# Global instance for easy access: created once when this module is imported
# and returned by get_llm_manager().
_llm_manager = LLMClientManager()
async def get_llm_manager() -> LLMClientManager:
    """Return the module-level ``LLMClientManager`` instance.

    The same object is returned on every call; nothing is created here.
    """
    return _llm_manager
# Convenience functions for backward compatibility
async def generate_text(
    messages: Union[str, list[LLMMessage]],
    provider: Optional[Union[str, AIProvider]] = None,
    model: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
    **kwargs,
) -> LLMResponse:
    """Generate text through the global LLM manager.

    Every argument is forwarded unchanged to
    ``LLMClientManager.generate_text`` on the instance returned by
    ``get_llm_manager()``; see that method for the parameter semantics.

    Returns:
        The LLMResponse produced by the manager.
    """
    shared_manager = await get_llm_manager()
    response = await shared_manager.generate_text(
        messages=messages,
        provider=provider,
        model=model,
        context=context,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=callbacks,
        **kwargs,
    )
    return response
async def stream_text(
    messages: Union[str, list[LLMMessage]],
    provider: Optional[Union[str, AIProvider]] = None,
    model: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    callbacks: Optional[List[CustomAsyncCallbackHandler]] = None,
    **kwargs,
):
    """Stream text through the global LLM manager.

    Every argument is forwarded unchanged to
    ``LLMClientManager.stream_text`` on the instance returned by
    ``get_llm_manager()``; see that method for the parameter semantics.

    Yields:
        Each chunk produced by the manager's stream, in order.
    """
    shared_manager = await get_llm_manager()
    chunk_stream = shared_manager.stream_text(
        messages=messages,
        provider=provider,
        model=model,
        context=context,
        temperature=temperature,
        max_tokens=max_tokens,
        callbacks=callbacks,
        **kwargs,
    )
    async for piece in chunk_stream:
        yield piece