# /*---------------------------------------------------------------------------------------------
# * Copyright (c) IRETBL Corporation. All rights reserved.
# * Licensed under the Apache-2.0. See License.txt in the project root for license information.
# *--------------------------------------------------------------------------------------------*/
import asyncio
import json
import logging
import uuid
import websockets
from typing import Dict, Any, Set, Optional, Callable
from websockets import serve, ServerConnection
from pydantic import BaseModel
logger = logging.getLogger(__name__)
[docs]
class UserConfirmation(BaseModel):
proceed: bool
feedback: Optional[str] = None
[docs]
class TaskStepResult(BaseModel):
step: str
result: Any = None
completed: bool = False
message: str
status: str
error_code: Optional[str] = None
error_message: Optional[str] = None
[docs]
class WebSocketManager:
"""
Specialized handler for WebSocket server and client communication
"""
[docs]
def __init__(self, host: str = "python-middleware-api", port: int = 8765):
self.host = host
self.port = port
self.server: Optional[Any] = None
self.callback_registry: Dict[str, Callable] = {}
self.active_connections: Set[ServerConnection] = set()
self._running = False
[docs]
async def start_server(self):
"""Start WebSocket server"""
if self.server:
logger.warning("WebSocket server is already running")
return self.server
try:
self.server = await serve(self._handle_client_connection, self.host, self.port)
self._running = True
logger.info(f"WebSocket server started on {self.host}:{self.port}")
return self.server
except Exception as e:
logger.error(f"Failed to start WebSocket server: {e}")
raise
[docs]
async def stop_server(self):
"""Stop WebSocket server"""
if self.server:
self.server.close()
await self.server.wait_closed()
self._running = False
logger.info("WebSocket server stopped")
# Close all active connections
if self.active_connections:
await asyncio.gather(
*[conn.close() for conn in self.active_connections],
return_exceptions=True,
)
self.active_connections.clear()
async def _handle_client_connection(self, websocket: ServerConnection):
"""Handle client connection"""
self.active_connections.add(websocket)
client_addr = websocket.remote_address
logger.info(f"New WebSocket connection from {client_addr}")
try:
async for message in websocket:
# Decode bytes to str if needed
message_str = message if isinstance(message, str) else message.decode("utf-8")
await self._handle_client_message(websocket, message_str)
except websockets.exceptions.ConnectionClosed:
logger.info(f"WebSocket connection closed: {client_addr}")
except Exception as e:
logger.error(f"WebSocket error for {client_addr}: {e}")
finally:
self.active_connections.discard(websocket)
try:
await websocket.close()
except Exception:
pass # Connection already closed
async def _handle_client_message(self, websocket: ServerConnection, message: str):
"""Handle client message"""
try:
data = json.loads(message)
action = data.get("action")
if action == "confirm":
await self._handle_confirmation(data)
elif action == "cancel":
await self._handle_cancellation(data)
elif action == "ping":
await self._handle_ping(websocket, data)
elif action == "subscribe":
await self._handle_subscription(websocket, data)
else:
logger.warning(f"Unknown action received: {action}")
await self._send_error(websocket, f"Unknown action: {action}")
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON received: {e}")
await self._send_error(websocket, "Invalid JSON format")
except Exception as e:
logger.error(f"Error handling client message: {e}")
await self._send_error(websocket, f"Internal error: {str(e)}")
async def _handle_confirmation(self, data: Dict[str, Any]):
"""Handle user confirmation"""
callback_id = data.get("callback_id")
if callback_id and callback_id in self.callback_registry:
callback = self.callback_registry[callback_id]
confirmation = UserConfirmation(
proceed=data.get("proceed", False),
feedback=data.get("feedback"),
)
try:
callback(confirmation)
del self.callback_registry[callback_id]
logger.debug(f"Processed confirmation for callback {callback_id}")
except Exception as e:
logger.error(f"Error processing confirmation callback: {e}")
else:
logger.warning(f"No callback found for confirmation ID: {callback_id}")
async def _handle_cancellation(self, data: Dict[str, Any]):
"""Handle task cancellation"""
user_id = data.get("user_id")
task_id = data.get("task_id")
if user_id and task_id:
# Task cancellation logic can be added here
# Since database manager access is needed, this functionality may
# need to be implemented through callbacks
logger.info(f"Task cancellation requested: user={user_id}, task={task_id}")
await self.broadcast_message(
{
"type": "task_cancelled",
"user_id": user_id,
"task_id": task_id,
"timestamp": asyncio.get_event_loop().time(),
}
)
else:
logger.warning("Invalid cancellation request: missing user_id or task_id")
async def _handle_ping(self, websocket: ServerConnection, data: Dict[str, Any]):
"""Handle heartbeat detection"""
pong_data = {
"type": "pong",
"timestamp": asyncio.get_event_loop().time(),
"original_data": data,
}
await self._send_to_client(websocket, pong_data)
async def _handle_subscription(self, websocket: ServerConnection, data: Dict[str, Any]):
"""Handle subscription request"""
user_id = data.get("user_id")
if user_id:
# User-specific subscription logic can be implemented here
logger.info(f"User {user_id} subscribed to updates")
await self._send_to_client(
websocket,
{"type": "subscription_confirmed", "user_id": user_id},
)
async def _send_error(self, websocket: ServerConnection, error_message: str):
"""Send error message to client"""
error_data = {
"type": "error",
"message": error_message,
"timestamp": asyncio.get_event_loop().time(),
}
await self._send_to_client(websocket, error_data)
async def _send_to_client(self, websocket: ServerConnection, data: Dict[str, Any]):
"""Send data to specific client"""
try:
await websocket.send(json.dumps(data))
except Exception as e:
logger.error(f"Failed to send message to client: {e}")
[docs]
async def notify_user(
self,
step_result: TaskStepResult,
user_id: str,
task_id: str,
step: int,
) -> UserConfirmation:
"""Notify user of task step result"""
callback_id = str(uuid.uuid4())
confirmation_future: asyncio.Future[UserConfirmation] = asyncio.Future()
# Register callback
self.callback_registry[callback_id] = lambda confirmation: confirmation_future.set_result(confirmation)
# Prepare notification data
notification_data = {
"type": "task_step_result",
"callback_id": callback_id,
"step": step,
"message": step_result.message,
"result": step_result.result,
"status": step_result.status,
"error_code": step_result.error_code,
"error_message": step_result.error_message,
"user_id": user_id,
"task_id": task_id,
"timestamp": asyncio.get_event_loop().time(),
}
try:
# Broadcast to all connected clients (can be optimized to send only
# to specific users)
await self.broadcast_message(notification_data)
# Wait for user confirmation with timeout
try:
# 5 minute timeout
return await asyncio.wait_for(confirmation_future, timeout=300)
except asyncio.TimeoutError:
logger.warning(f"User confirmation timeout for callback {callback_id}")
# Clean up callback
self.callback_registry.pop(callback_id, None)
return UserConfirmation(proceed=True) # Default to proceed
except Exception as e:
logger.error(f"WebSocket notification error: {e}")
# Clean up callback
self.callback_registry.pop(callback_id, None)
return UserConfirmation(proceed=True) # Default to proceed
[docs]
async def send_heartbeat(self, user_id: str, task_id: str, interval: int = 30):
"""Send heartbeat message"""
heartbeat_data = {
"type": "heartbeat",
"status": "heartbeat",
"message": "Task is still executing...",
"user_id": user_id,
"task_id": task_id,
"timestamp": asyncio.get_event_loop().time(),
}
while self._running:
try:
await self.broadcast_message(heartbeat_data)
await asyncio.sleep(interval)
except Exception as e:
logger.error(f"WebSocket heartbeat error: {e}")
break
[docs]
async def broadcast_message(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients"""
if not self.active_connections:
logger.debug("No active WebSocket connections for broadcast")
return
# Filter out closed connections (use try-except to handle closed connections)
active_connections = []
for conn in list(self.active_connections):
try:
# Try to check if connection is still valid
active_connections.append(conn)
except Exception:
pass # Connection is closed, skip it
self.active_connections = set(active_connections)
if active_connections:
await asyncio.gather(
*[self._send_to_client(conn, message) for conn in active_connections],
return_exceptions=True,
)
logger.debug(f"Broadcasted message to {len(active_connections)} clients")
[docs]
async def send_to_user(self, user_id: str, message: Dict[str, Any]):
"""Send message to specific user (requires user connection mapping implementation)"""
# User ID to WebSocket connection mapping can be implemented here
# Currently simplified to broadcast
message["target_user_id"] = user_id
await self.broadcast_message(message)
[docs]
def get_connection_count(self) -> int:
"""Get active connection count"""
return len(self.active_connections)
[docs]
def get_status(self) -> Dict[str, Any]:
"""Get WebSocket manager status"""
return {
"running": self._running,
"host": self.host,
"port": self.port,
"active_connections": self.get_connection_count(),
"pending_callbacks": len(self.callback_registry),
}