Source code for aiecs.tools.base_tool
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) IRETBL Corporation. All rights reserved.
# * Licensed under the Apache-2.0. See License.txt in the project root for license information.
# *--------------------------------------------------------------------------------------------*/
import inspect
import logging
from typing import Any, Dict, List, Optional, Type
from pydantic import BaseModel, ValidationError
import re
from aiecs.tools.tool_executor import (
InputValidationError,
SecurityError,
get_executor,
ExecutorConfig,
)
from aiecs.config.tool_config import get_tool_config_loader
from aiecs.tools.schema_generator import generate_schema_from_method
# Module-level logger named after this module's dotted path.
logger = logging.getLogger(__name__)
[docs]
class BaseTool:
    """
    Base class for all tools.

    Functionality provided to subclasses:
    - Configuration: an optional inner ``Config`` pydantic model is detected
      and populated through the tool config loader (see ``__init__``).
    - Input validation: per-operation Pydantic schemas, either declared as
      inner ``<Op>Schema`` classes or auto-generated from method signatures.
    - Input sanitization: top-level string arguments are screened by
      ``_sanitize_kwargs`` before execution.
    - Execution: ``run``, ``run_async`` and ``run_batch`` delegate to the
      executor returned by ``get_executor``. Caching with TTL, concurrency,
      retries, metrics and structured logging are presumably provided by
      that executor. NOTE(review): not visible in this module; confirm in
      aiecs.tools.tool_executor.

    Tools inheriting from this class focus on business logic, leveraging
    the executor's cross-cutting concerns.

    Example:
        class MyTool(BaseTool):
            class ReadSchema(BaseModel):
                path: str

            @validate_input(ReadSchema)
            @cache_result(ttl=300)
            @run_in_executor
            @measure_execution_time
            @sanitize_input
            def read(self, path: str):
                # Implementation
                pass

        NOTE(review): the decorators used above are not imported by this
        module; presumably they come from the executor module. Confirm.
    """
[docs]
def __init__(self, config: Optional[Dict[str, Any]] = None, tool_name: Optional[str] = None):
    """
    Set up the tool.

    Steps: resolve the configuration, create the executor, register operation
    schemas and async methods, then log schema coverage.

    Configuration precedence (highest first), as resolved by the tool config
    loader:
    1. Explicit ``config`` dict
    2. YAML config files (config/tools/{tool_name}.yaml or config/tools.yaml)
    3. Environment variables (via dotenv from .env files)
    4. Tool defaults

    Args:
        config: Tool-specific settings that override every other source.
            When None (or empty), configuration is loaded automatically.
        tool_name: Registered tool name used for config-file discovery.
            Defaults to the class name.

    Raises:
        ValidationError: If the tool declares a ``Config`` model and the
            resolved configuration fails validation against it.
    """
    detected_model = self._detect_config_class()
    tool_name = self.__class__.__name__ if tool_name is None else tool_name

    if not detected_model:
        # Legacy mode: no Config model, so nothing to validate against.
        if not config:
            # No explicit settings; still let YAML / environment supply values.
            loader = get_tool_config_loader()
            try:
                self._config = loader.load_tool_config(
                    tool_name=tool_name,
                    config_schema=None,
                    explicit_config=None,
                )
            except Exception as e:
                logger.debug(f"Could not load config for {tool_name}: {e}. Using empty config.")
                self._config = {}
        else:
            self._config = config
        self._config_obj = None
    else:
        loader = get_tool_config_loader()
        try:
            resolved = loader.load_tool_config(
                tool_name=tool_name,
                config_schema=detected_model,
                explicit_config=config,
            )
            self._config_obj: Optional[BaseModel] = detected_model(**resolved)
            self._config = resolved
        except ValidationError as e:
            # Invalid configuration is fatal; surface it to the caller.
            logger.error(f"Configuration validation failed for {tool_name}: {e}")
            raise
        except Exception as e:
            # Loader failure: fall back to whatever was passed explicitly.
            logger.warning(f"Failed to load configuration for {tool_name}: {e}. Using defaults.")
            self._config = config or {}
            try:
                self._config_obj = detected_model(**self._config)
            except Exception:
                # Even the fallback values do not validate; run without a model.
                self._config_obj = None

    # Hand the executor only the keys ExecutorConfig understands; tool-specific
    # keys (e.g. user_agent, temp_dir) are filtered out.
    self._executor = get_executor(self._extract_executor_config(self._config))
    self._schemas: Dict[str, Type[BaseModel]] = {}
    self._async_methods: List[str] = []
    # Counters filled in by _register_schemas / _get_method_schema.
    self._schema_coverage: Dict[str, Any] = dict(
        total_methods=0,
        manual_schemas=0,
        auto_generated_schemas=0,
        missing_schemas=0,
        schema_quality={},
    )
    self._register_schemas()
    self._register_async_methods()
    self._log_schema_coverage()
def _extract_executor_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract only executor-related configuration fields from the full config.
This prevents tool-specific fields (e.g., user_agent, temp_dir) from being
passed to ExecutorConfig, which would cause validation issues or be silently
ignored.
Args:
config (Dict[str, Any]): Full configuration dictionary.
Returns:
Dict[str, Any]: Filtered configuration containing only ExecutorConfig fields.
"""
if not config:
return {}
# Get all valid field names from ExecutorConfig
executor_fields = set(ExecutorConfig.model_fields.keys())
# Filter config to only include executor-related fields
executor_config = {key: value for key, value in config.items() if key in executor_fields}
return executor_config
def _detect_config_class(self) -> Optional[Type[BaseModel]]:
    """
    Find the tool's ``Config`` model by walking the class hierarchy.

    The method resolution order is scanned from most- to least-derived. The
    first ``Config`` attribute (own or inherited) that is a class deriving
    from pydantic ``BaseModel`` is returned. pydantic-settings'
    ``BaseSettings`` is itself a ``BaseModel`` subclass, so settings classes
    are accepted without a separate check. A ``Config`` attribute that is not
    such a class (e.g. a plain inner class) is skipped, and the search
    continues with the base classes.

    Returns:
        The Config class if found, otherwise None.
    """
    # __mro__ already starts with the class itself, so no need to prepend it.
    for klass in self.__class__.__mro__:
        candidate = getattr(klass, "Config", None)
        if isinstance(candidate, type) and issubclass(candidate, BaseModel):
            return candidate
    return None
@property
def settings(self) -> Any:
    """
    Backward-compatible alias for the tool's validated configuration.

    Used by tests and legacy callers. Lookup order:
    1. a ``config`` attribute on the instance, if a subclass defines one;
    2. the Config object built in ``__init__``;
    3. a freshly default-constructed detected Config model;
    4. the raw configuration dict.
    """
    if hasattr(self, "config"):
        return self.config
    built = self._config_obj
    if built is not None:
        return built
    model = self._detect_config_class()
    return self._config if model is None else model()
def _register_schemas(self) -> None:
    """
    Register Pydantic schemas for the tool's operations.

    The work is done in two passes:

    1. Manual schemas. Every class attribute that is a ``BaseModel``
       subclass whose ``__name__`` ends with ``"Schema"`` is registered under
       the operation name derived from the rest of its name
       (``Read_csvSchema`` -> ``read_csv``, ``ReadSchema`` -> ``read``).
    2. Auto-generation. Each public method still lacking a schema gets one
       from ``generate_schema_from_method``. A ``*_async`` method reuses the
       schema of its sync counterpart when that one is registered.

    Counters in ``self._schema_coverage`` are updated along the way.

    Example:
        class MyTool(BaseTool):
            class ReadSchema(BaseModel):
                path: str
            def read(self, path: str):
                pass
        # Registers 'read' -> ReadSchema (manual)
        # Auto-generates a schema for 'write' if WriteSchema doesn't exist
    """
    cls = self.__class__
    schema_suffix = "Schema"
    async_suffix = "_async"

    # First pass: manual schemas declared as inner classes.
    for attr_name in dir(cls):
        attr = getattr(cls, attr_name)
        if not (isinstance(attr, type) and issubclass(attr, BaseModel) and attr.__name__.endswith(schema_suffix)):
            continue
        # Strip only the trailing suffix. str.replace would also remove
        # "Schema" appearing elsewhere in the name (e.g. SchemaInfoSchema).
        schema_base_name = attr.__name__[: -len(schema_suffix)]
        if not schema_base_name:
            # A model named exactly "Schema" maps to no operation.
            continue
        op_name = self._normalize_schema_name_to_method(schema_base_name)
        self._schemas[op_name] = attr
        self._schema_coverage["manual_schemas"] += 1
        logger.debug(f"Registered manual schema {attr.__name__} -> {op_name}")

    # Second pass: auto-generate schemas for public methods still lacking one.
    public_methods = self._get_public_methods()
    self._schema_coverage["total_methods"] = len(public_methods)
    for method_name in public_methods:
        if method_name in self._schemas:
            continue
        # Async wrappers share the schema of their sync counterpart. dir() is
        # alphabetical, so "x" has already been handled when "x_async" is seen.
        if method_name.endswith(async_suffix):
            sync_method_name = method_name[: -len(async_suffix)]
            if sync_method_name in self._schemas:
                self._schemas[method_name] = self._schemas[sync_method_name]
                logger.debug(f"Reusing schema for async method {method_name} from {sync_method_name}")
                continue
        method = getattr(cls, method_name)
        if not callable(method) or isinstance(method, type):
            continue
        try:
            auto_schema = generate_schema_from_method(method, method_name)
        except Exception as e:
            # Best effort: one bad signature must not stop tool construction.
            self._schema_coverage["missing_schemas"] += 1
            logger.warning(f"Failed to auto-generate schema for {method_name}: {e}")
            continue
        if auto_schema:
            self._schemas[method_name] = auto_schema
            self._schema_coverage["auto_generated_schemas"] += 1
            logger.info(f"Auto-generated schema for method {method_name} -> {auto_schema.__name__}")
        else:
            self._schema_coverage["missing_schemas"] += 1
            logger.debug(f"No schema generated for method {method_name} (no parameters)")
def _normalize_schema_name_to_method(self, schema_base_name: str) -> str:
"""
Convert schema name to method name.
Handles conventions like:
- Read_csvSchema -> read_csv
- ReadCsvSchema -> readcsv (fallback, but should use Read_csvSchema)
- ReadSchema -> read
Args:
schema_base_name: Schema name without "Schema" suffix
Returns:
Normalized method name
"""
# If name contains underscores, preserve them (e.g., Read_csv -> read_csv)
if "_" in schema_base_name:
# Convert first letter to lowercase, keep rest as-is
if schema_base_name:
return schema_base_name[0].lower() + schema_base_name[1:]
return schema_base_name.lower()
# Convert CamelCase to snake_case
# Insert underscore before uppercase letters (except first)
result = []
for i, char in enumerate(schema_base_name):
if char.isupper() and i > 0:
result.append("_")
result.append(char.lower())
return "".join(result)
def _get_public_methods(self) -> List[str]:
"""
Get list of public methods that should have schemas.
Returns:
List of method names
"""
methods = []
for attr_name in dir(self.__class__):
# Skip private methods
if attr_name.startswith("_"):
continue
# Skip base class methods
if attr_name in ["run", "run_async", "run_batch"]:
continue
attr = getattr(self.__class__, attr_name)
# Skip non-method attributes
if not callable(attr):
continue
# Skip classes (like Config, Schema, etc.)
if isinstance(attr, type):
continue
methods.append(attr_name)
return methods
def _log_schema_coverage(self) -> None:
"""
Log schema coverage metrics after registration.
"""
coverage = self._schema_coverage
total = coverage["total_methods"]
if total == 0:
return
manual = coverage["manual_schemas"]
auto = coverage["auto_generated_schemas"]
missing = coverage["missing_schemas"]
coverage_pct = ((manual + auto) / total * 100) if total > 0 else 0
logger.info(f"Schema coverage for {self.__class__.__name__}: " f"{coverage_pct:.1f}% ({manual + auto}/{total}) - " f"Manual: {manual}, Auto: {auto}, Missing: {missing}")
if missing > 0:
logger.debug(f"{missing} methods without schemas in {self.__class__.__name__}")
[docs]
def get_schema_coverage(self) -> Dict[str, Any]:
    """
    Report schema coverage metrics for this tool.

    Returns:
        A dict with:
        - total_methods: number of public methods
        - manual_schemas: number of manually defined schemas
        - auto_generated_schemas: number of auto-generated schemas
        - missing_schemas: number of methods without a schema
        - coverage_percentage: (manual + auto) / total * 100, or 0 when there
          are no methods
        - quality_metrics: result of ``_calculate_schema_quality()``
    """
    stats = self._schema_coverage
    total, manual, auto, missing = (
        stats["total_methods"],
        stats["manual_schemas"],
        stats["auto_generated_schemas"],
        stats["missing_schemas"],
    )
    if total > 0:
        percentage = (manual + auto) / total * 100
    else:
        percentage = 0
    report: Dict[str, Any] = {}
    report["total_methods"] = total
    report["manual_schemas"] = manual
    report["auto_generated_schemas"] = auto
    report["missing_schemas"] = missing
    report["coverage_percentage"] = percentage
    report["quality_metrics"] = self._calculate_schema_quality()
    return report
def _calculate_schema_quality(self) -> Dict[str, float]:
"""
Calculate schema quality metrics.
Returns:
Dictionary with quality scores:
- description_quality: Percentage of fields with meaningful descriptions
- type_coverage: Percentage of fields with type annotations
- overall_score: Overall quality score
"""
total_fields = 0
fields_with_descriptions = 0
fields_with_types = 0
for schema in self._schemas.values():
if not hasattr(schema, "model_fields"):
continue
for field_name, field_info in schema.model_fields.items():
total_fields += 1
# Check for meaningful description (not just "Parameter {name}")
desc = field_info.description
if desc and desc != f"Parameter {field_name}":
fields_with_descriptions += 1
# Check for type annotation
if field_info.annotation is not None and field_info.annotation != Any:
fields_with_types += 1
description_quality = (fields_with_descriptions / total_fields * 100) if total_fields > 0 else 0
type_coverage = (fields_with_types / total_fields * 100) if total_fields > 0 else 0
overall_score = (description_quality + type_coverage) / 2 if total_fields > 0 else 0
return {
"description_quality": description_quality,
"type_coverage": type_coverage,
"overall_score": overall_score,
}
def _register_async_methods(self) -> None:
"""
Register async methods for proper execution handling.
"""
for attr_name in dir(self.__class__):
attr = getattr(self.__class__, attr_name)
if inspect.iscoroutinefunction(attr) and not attr_name.startswith("_"):
self._async_methods.append(attr_name)
def _sanitize_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitize keyword arguments to prevent injection attacks.
Args:
kwargs (Dict[str, Any]): Input keyword arguments.
Returns:
Dict[str, Any]: Sanitized keyword arguments.
Raises:
SecurityError: If kwargs contain malicious content.
"""
sanitized = {}
for k, v in kwargs.items():
if isinstance(v, str) and re.search(r"(\bSELECT\b|\bINSERT\b|--|;|/\*)", v, re.IGNORECASE):
raise SecurityError(f"Input parameter '{k}' contains potentially malicious content")
sanitized[k] = v
return sanitized
[docs]
def run(self, op: str, **kwargs) -> Any:
    """
    Execute a synchronous operation with parameters.

    If a schema is registered for ``op``, the arguments are validated
    against it first. Only explicitly supplied fields are kept
    (``exclude_unset=True``), so schema defaults are not injected into the
    call. The result is then sanitized and handed to the executor.

    Args:
        op: Name of the operation to execute.
        **kwargs: Parameters for the operation.

    Returns:
        Whatever the executor returns for the operation.

    Raises:
        InputValidationError: If the parameters fail schema validation; the
            original pydantic ValidationError is chained as ``__cause__``.
        SecurityError: If a string parameter looks malicious.
        ToolExecutionError: If the operation fails (raised by the executor).
    """
    schema_class = self._schemas.get(op)
    if schema_class:
        try:
            validated = schema_class(**kwargs)
        except ValidationError as e:
            raise InputValidationError(f"Invalid input parameters: {e}") from e
        kwargs = validated.model_dump(exclude_unset=True)
    kwargs = self._sanitize_kwargs(kwargs)
    return self._executor.execute(self, op, **kwargs)
[docs]
async def run_async(self, op: str, **kwargs) -> Any:
    """
    Execute an asynchronous operation with parameters.

    Validation and sanitization are identical to ``run``. A registered schema
    validates the arguments, only explicitly supplied fields are kept
    (``exclude_unset=True``), and the result is sanitized before being
    awaited on the executor.

    Args:
        op: Name of the operation to execute.
        **kwargs: Parameters for the operation.

    Returns:
        Whatever the executor's ``execute_async`` resolves to.

    Raises:
        InputValidationError: If the parameters fail schema validation; the
            original pydantic ValidationError is chained as ``__cause__``.
        SecurityError: If a string parameter looks malicious.
        ToolExecutionError: If the operation fails (raised by the executor).
    """
    schema_class = self._schemas.get(op)
    if schema_class:
        try:
            validated = schema_class(**kwargs)
        except ValidationError as e:
            raise InputValidationError(f"Invalid input parameters: {e}") from e
        kwargs = validated.model_dump(exclude_unset=True)
    kwargs = self._sanitize_kwargs(kwargs)
    return await self._executor.execute_async(self, op, **kwargs)
[docs]
async def run_batch(self, operations: List[Dict[str, Any]]) -> List[Any]:
    """
    Run several operations through the executor's batch API.

    Args:
        operations: Operation dicts, each carrying 'op' and 'kwargs'.

    Returns:
        The list returned by the executor's ``execute_batch``.

    Raises:
        ToolExecutionError: If any operation fails (raised by the executor).
        InputValidationError: If input parameters are invalid.
    """
    batch_executor = self._executor
    outcomes = await batch_executor.execute_batch(self, operations)
    return outcomes
def _get_method_schema(self, method_name: str) -> Optional[Type[BaseModel]]:
    """
    Return the schema for a method, registering it on first discovery.

    Lookup order:
    1. Schemas already registered (manual or auto-generated).
    2. A manual schema class found on the tool class by naming convention.
       Both spellings are tried: CamelCase (``read_csv`` -> ``ReadCsvSchema``)
       and underscore-preserving (``read_csv`` -> ``Read_csvSchema``).
    3. Auto-generation from the method's signature.

    Args:
        method_name: Name of the method.

    Returns:
        The schema class, or None if none exists and none could be generated.
    """
    if method_name in self._schemas:
        schema = self._schemas[method_name]
        schema_type = "manual" if self._is_manual_schema(method_name, schema) else "auto-generated"
        logger.debug(f"Retrieved {schema_type} schema for method {method_name}")
        return schema

    cls = self.__class__

    # Convention-based manual lookup; direct getattr replaces a dir() scan.
    candidate_names = [self._method_name_to_schema_name(method_name)]
    if method_name:
        underscore_name = method_name[0].upper() + method_name[1:] + "Schema"
        if underscore_name not in candidate_names:
            candidate_names.append(underscore_name)
    for schema_name in candidate_names:
        attr = getattr(cls, schema_name, None)
        if isinstance(attr, type) and issubclass(attr, BaseModel):
            # Cache it so later lookups hit the registry directly.
            self._schemas[method_name] = attr
            self._schema_coverage["manual_schemas"] += 1
            logger.debug(f"Found and registered manual schema {schema_name} for method {method_name}")
            return attr

    # Fallback: auto-generate from the method signature when the method exists.
    method = getattr(cls, method_name, None)
    if callable(method) and not isinstance(method, type):
        try:
            auto_schema = generate_schema_from_method(method, method_name)
        except Exception as e:
            logger.debug(f"Could not auto-generate schema for {method_name}: {e}")
            return None
        if auto_schema:
            self._schemas[method_name] = auto_schema
            self._schema_coverage["auto_generated_schemas"] += 1
            logger.info(f"Auto-generated schema on-demand for method {method_name}")
            return auto_schema
    return None
def _method_name_to_schema_name(self, method_name: str) -> str:
"""
Convert method name to schema name following convention.
Examples:
- read_csv -> Read_csvSchema
- read -> ReadSchema
Args:
method_name: Method name in snake_case
Returns:
Schema class name
"""
# Preserve underscores: read_csv -> Read_csv
parts = method_name.split("_")
capitalized_parts = [part.capitalize() for part in parts]
return "".join(capitalized_parts) + "Schema"
def _is_manual_schema(self, method_name: str, schema: Type[BaseModel]) -> bool:
    """
    Tell whether ``schema`` is a manually declared schema class.

    A schema counts as manual when the tool class exposes an attribute named
    ``schema.__name__`` that is that very class object.

    Args:
        method_name: Method name. Currently unused; kept for interface
            compatibility.
        schema: Schema class to check.

    Returns:
        True if the schema is defined on the tool class, False otherwise
        (e.g. when it was auto-generated).
    """
    declared = getattr(self.__class__, schema.__name__, None)
    # Identity, not equality: the question is whether it is the same class object.
    return isinstance(declared, type) and declared is schema