Source code for aiecs.llm.clients.xai_client

# /*---------------------------------------------------------------------------------------------
#  *  Copyright (c) IRETBL Corporation. All rights reserved.
#  *  Licensed under the Apache-2.0. See License.txt in the project root for license information.
#  *--------------------------------------------------------------------------------------------*/
from openai import AsyncOpenAI
from aiecs.config.config import get_settings
from aiecs.llm.clients.base_client import (
    BaseLLMClient,
    LLMMessage,
    LLMResponse,
    ProviderNotAvailableError,
    RateLimitError,
)
from aiecs.llm.clients.openai_compatible_mixin import (
    OpenAICompatibleFunctionCallingMixin,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
import logging
from typing import Dict, Optional, List, AsyncGenerator, Any

# Lazy import to avoid circular dependency


def _get_config_loader():
    """Return the LLM config loader.

    The import is performed inside the function body rather than at module
    level so that importing this module does not immediately import
    ``aiecs.llm.config`` (this sidesteps a circular dependency).
    """
    from aiecs.llm.config import get_llm_config_loader as _load_llm_config_loader

    config_loader = _load_llm_config_loader()
    return config_loader


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class XAIClient(BaseLLMClient, OpenAICompatibleFunctionCallingMixin):
    """xAI (Grok) provider client.

    Requests are sent with the OpenAI SDK's ``AsyncOpenAI`` client pointed at
    ``https://api.x.ai/v1``; text generation and streaming are delegated to
    the function-calling helpers inherited from
    ``OpenAICompatibleFunctionCallingMixin``.
    """

    def __init__(self) -> None:
        """Register the provider name and load settings; no network I/O."""
        super().__init__("xAI")
        self.settings = get_settings()
        # Both attributes are filled in lazily on first use.
        self._openai_client: Optional[AsyncOpenAI] = None
        self._model_map: Optional[Dict[str, str]] = None

    def _get_openai_client(self) -> AsyncOpenAI:
        """Return the cached ``AsyncOpenAI`` client, creating it on first use.

        Raises:
            ProviderNotAvailableError: If the client has to be created and no
                API key is configured.
        """
        if self._openai_client is None:
            api_key = self._get_api_key()
            self._openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url="https://api.x.ai/v1",
                # Override the default timeout with a longer one for
                # reasoning models.
                timeout=360.0,
            )
        return self._openai_client

    def _get_api_key(self) -> str:
        """Return the configured API key.

        ``xai_api_key`` is preferred; ``grok_api_key`` is still honoured for
        backward compatibility.

        Raises:
            ProviderNotAvailableError: If neither setting holds a truthy value.
        """
        api_key = getattr(self.settings, "xai_api_key", None) or getattr(
            self.settings, "grok_api_key", None
        )
        if not api_key:
            raise ProviderNotAvailableError("xAI API key not configured")
        return str(api_key)

    def _get_model_map(self) -> Dict[str, str]:
        """Return the model-name mappings from configuration (cached).

        Loading is best-effort: any failure is logged as a warning and an
        empty mapping is cached, so model names are then used as given.
        """
        if self._model_map is None:
            try:
                loader = _get_config_loader()
                provider_config = loader.get_provider_config("xAI")
                if provider_config and provider_config.model_mappings:
                    self._model_map = provider_config.model_mappings
                else:
                    self._model_map = {}
            except Exception as e:
                # Deliberately broad: a broken config must not prevent
                # requests from being sent.
                self.logger.warning(f"Failed to load model mappings from config: {e}")
                self._model_map = {}
        return self._model_map

    def _resolve_xai_model(self, model: Optional[str]) -> tuple:
        """Return ``(selected_model, api_model)`` for a request.

        ``selected_model`` is the caller's model, else the configured default,
        else ``"grok-4"``. ``api_model`` is that name passed through the
        configured model mappings (unmapped names are used unchanged).
        """
        selected_model = model or self._get_default_model() or "grok-4"
        api_model = self._get_model_map().get(selected_model, selected_model)
        return selected_model, api_model

    @staticmethod
    def _is_xai_rate_limit_error(error: Exception) -> bool:
        """Return True if the error text mentions "rate limit" or "429"."""
        text = str(error)
        return "rate limit" in text.lower() or "429" in text

    # NOTE(review): the retry predicate matches every ``Exception``, so
    # non-transient failures such as ProviderNotAvailableError are retried
    # too. Kept as-is to preserve behaviour; confirm this is intended.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((Exception, RateLimitError)),
    )
    async def generate_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[Any] = None,
        input_price: Optional[float] = None,
        output_price: Optional[float] = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate text using the xAI API via the OpenAI library.

        The xAI API is OpenAI-compatible, so Function Calling is supported
        through the mixin helper this method delegates to.

        Args:
            messages: List of conversation messages.
            model: Model name; falls back to the configured default, then to
                ``"grok-4"``.
            temperature: Sampling temperature, forwarded unchanged.
            max_tokens: Maximum tokens to generate.
            context: Optional metadata dictionary (e.g. ``user_id``,
                ``tenant_id``, ``request_id``, ``session_id``). It is accepted
                for interface compatibility and is not read by this method.
            functions: List of function schemas (legacy format).
            tools: List of tool schemas (new format, recommended).
            tool_choice: Tool choice strategy.
            input_price: Forwarded to the mixin helper; presumably an input
                token price override for cost calculation (confirm there).
            output_price: Forwarded to the mixin helper; presumably an output
                token price override for cost calculation (confirm there).
            **kwargs: Additional provider-specific parameters, forwarded.

        Returns:
            LLMResponse whose ``provider`` is this client's provider name and
            whose ``model`` is the selected (unmapped) model name.

        Raises:
            ProviderNotAvailableError: If no API key is configured.
            RateLimitError: If the underlying error text mentions
                "rate limit" or "429".
            Exception: Any other error from the underlying call is logged and
                re-raised. All of these are raised inside the retried call;
                how they surface after the last attempt is governed by the
                ``tenacity`` decorator.
        """
        # Validate the key on every call, even when the client is cached;
        # _get_api_key raises if it is missing, so no further check is needed.
        self._get_api_key()
        client = self._get_openai_client()

        selected_model, api_model = self._resolve_xai_model(model)

        try:
            # Use mixin method for Function Calling support
            response = await self._generate_text_with_function_calling(
                client=client,
                messages=messages,
                model=api_model,
                temperature=temperature,
                max_tokens=max_tokens,
                functions=functions,
                tools=tools,
                tool_choice=tool_choice,
                input_price=input_price,
                output_price=output_price,
                **kwargs,
            )

            # Report the xAI provider and the name the caller selected,
            # not the mapped API model name.
            response.provider = self.provider_name
            response.model = selected_model

            return response

        except Exception as e:
            if self._is_xai_rate_limit_error(e):
                raise RateLimitError(f"xAI rate limit exceeded: {str(e)}") from e
            logger.error(f"xAI API error: {str(e)}")
            raise

    async def stream_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[Any] = None,
        return_chunks: bool = False,
        input_price: Optional[float] = None,
        output_price: Optional[float] = None,
        **kwargs,
    ) -> AsyncGenerator[Any, None]:
        """Stream text using the xAI API via the OpenAI library.

        The xAI API is OpenAI-compatible, so Function Calling is supported
        through the mixin helper this method delegates to. Unlike
        ``generate_text``, this method has no retry decorator.

        Args:
            messages: List of conversation messages.
            model: Model name; falls back to the configured default, then to
                ``"grok-4"``.
            temperature: Sampling temperature, forwarded unchanged.
            max_tokens: Maximum tokens to generate.
            context: Optional metadata dictionary (e.g. ``user_id``,
                ``tenant_id``, ``request_id``, ``session_id``). It is accepted
                for interface compatibility and is not read by this method.
            functions: List of function schemas (legacy format).
            tools: List of tool schemas (new format, recommended).
            tool_choice: Tool choice strategy.
            return_chunks: If True, yields StreamChunk objects with tool_calls
                info; if False, yields str tokens only.
            input_price: Accepted for signature parity with ``generate_text``;
                not forwarded to the streaming helper.
            output_price: Accepted for signature parity with
                ``generate_text``; not forwarded to the streaming helper.
            **kwargs: Additional provider-specific parameters, forwarded.

        Yields:
            str or StreamChunk: Text tokens or StreamChunk objects.

        Raises:
            ProviderNotAvailableError: If no API key is configured.
            RateLimitError: If the underlying error text mentions
                "rate limit" or "429".
            Exception: Any other error is logged and re-raised.
        """
        # Validate the key on every call, even when the client is cached;
        # _get_api_key raises if it is missing, so no further check is needed.
        self._get_api_key()
        client = self._get_openai_client()

        # Only the mapped API model name is needed when streaming.
        _, api_model = self._resolve_xai_model(model)

        try:
            # Use mixin method for Function Calling support
            async for chunk in self._stream_text_with_function_calling(
                client=client,
                messages=messages,
                model=api_model,
                temperature=temperature,
                max_tokens=max_tokens,
                functions=functions,
                tools=tools,
                tool_choice=tool_choice,
                return_chunks=return_chunks,
                **kwargs,
            ):
                yield chunk
        except Exception as e:
            if self._is_xai_rate_limit_error(e):
                raise RateLimitError(f"xAI rate limit exceeded: {str(e)}") from e
            logger.error(f"xAI API streaming error: {str(e)}")
            raise

    async def close(self) -> None:
        """Close the underlying client, if one was created, and drop it."""
        if self._openai_client is not None:
            await self._openai_client.close()
            self._openai_client = None