Spaces:
Runtime error
Runtime error
| """Session management service for chat agent.""" | |
| import json | |
| import logging | |
| from datetime import datetime, timedelta | |
| from typing import Optional, List, Dict, Any | |
| from uuid import uuid4 | |
| import redis | |
| from sqlalchemy.exc import SQLAlchemyError | |
| from ..models.chat_session import ChatSession | |
| from ..models.base import db | |
# Module-level logger named after this module so log records can be
# filtered/configured per module.
logger = logging.getLogger(name=__name__)
class SessionManagerError(Exception):
    """Base exception for session manager errors.

    The more specific session exceptions in this module subclass it.
    """
class SessionNotFoundError(SessionManagerError):
    """Raised when a requested session cannot be found."""
class SessionExpiredError(SessionManagerError):
    """Raised when a session exists but ``ChatSession.is_expired`` reports it
    as past the manager's timeout."""
class SessionManager:
    """Manages user chat sessions with Redis caching and PostgreSQL persistence.

    ``ChatSession`` rows in the database are the source of truth. Redis holds
    a JSON copy of each session under ``session:<id>`` and a set of session
    ids per user under ``user_sessions:<user_id>``. All Redis access is
    best-effort: the helpers are no-ops when ``redis_client`` is falsy and log
    (rather than raise) on Redis errors, so a cache outage degrades to
    database-only operation.
    """

    # Extra seconds a cache entry outlives the session timeout.
    CACHE_TTL_BUFFER_SECONDS = 300

    def __init__(self, redis_client: Optional[redis.Redis], session_timeout: int = 3600):
        """
        Initialize the session manager.

        Args:
            redis_client: Redis client instance for caching, or ``None`` to
                run without a cache
            session_timeout: Session timeout in seconds (default: 1 hour)
        """
        self.redis_client = redis_client
        self.session_timeout = session_timeout
        self.cache_prefix = "session:"
        self.user_sessions_prefix = "user_sessions:"

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_session(self, user_id: str, language: str = 'python',
                       session_metadata: Optional[Dict[str, Any]] = None) -> ChatSession:
        """
        Create a new chat session.

        The session is persisted first; caching it and registering it in the
        user's Redis session set are best-effort afterwards.

        Args:
            user_id: User identifier
            language: Programming language for the session (default: python)
            session_metadata: Additional session metadata (must be
                JSON-serializable for the session to be cached)

        Returns:
            ChatSession: The created session

        Raises:
            SessionManagerError: If session creation fails in the database
        """
        try:
            session = ChatSession.create_session(
                user_id=user_id,
                language=language,
                session_metadata=session_metadata or {}
            )
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error creating session: %s", e)
            raise SessionManagerError(f"Failed to create session: {e}") from e

        # Both helpers swallow Redis errors themselves, so the session is
        # returned even when the cache is unavailable.
        self._cache_session(session)
        self._add_to_user_sessions(user_id, session.id)

        logger.info("Created new session %s for user %s", session.id, user_id)
        return session

    def get_session(self, session_id: str) -> ChatSession:
        """
        Get a session by ID, checking cache first then database.

        A cache hit returns an object rebuilt from the cached JSON rather than
        one loaded through ``db.session``; treat it as read-only. The mutating
        methods of this class reload the session from the database instead.

        Args:
            session_id: Session identifier

        Returns:
            ChatSession: The session object

        Raises:
            SessionNotFoundError: If no active session with this ID exists
            SessionExpiredError: If session has expired
            SessionManagerError: If the database lookup fails
        """
        cached_session = self._get_cached_session(session_id)
        if cached_session is not None:
            if not cached_session.is_active:
                # Stale entry for a deactivated session: drop it and let the
                # database lookup (which only returns active rows) decide.
                self._remove_from_cache(session_id)
            elif self._is_session_expired(cached_session):
                self._expire_session(session_id)
                raise SessionExpiredError(f"Session {session_id} has expired")
            else:
                return cached_session

        return self._load_session_from_db(session_id)

    def update_session_activity(self, session_id: str) -> None:
        """
        Update session activity timestamp.

        Args:
            session_id: Session identifier

        Raises:
            SessionNotFoundError: If session doesn't exist
            SessionExpiredError: If session has expired
            SessionManagerError: If the update fails
        """
        try:
            # Load from the DB (not the cache) so the change is applied to a
            # database-backed object.
            session = self._load_session_from_db(session_id)
            session.update_activity()
            self._cache_session(session)
            logger.debug("Updated activity for session %s", session_id)
        except SessionManagerError:
            # Already a domain error (incl. not-found / expired); don't re-wrap.
            raise
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Error updating session activity: %s", e)
            raise SessionManagerError(f"Failed to update session activity: {e}") from e
        except Exception as e:
            logger.error("Error updating session activity: %s", e)
            raise SessionManagerError(f"Failed to update session activity: {e}") from e

    def get_user_sessions(self, user_id: str, active_only: bool = True) -> List[ChatSession]:
        """
        Get all sessions for a user, most recently active first.

        With ``active_only``, sessions that turn out to be expired are
        deactivated and removed from the Redis cache and user session set as
        a side effect.

        Args:
            user_id: User identifier
            active_only: Whether to return only active, unexpired sessions

        Returns:
            List[ChatSession]: List of user sessions

        Raises:
            SessionManagerError: If the database query fails
        """
        try:
            query = db.session.query(ChatSession).filter(ChatSession.user_id == user_id)
            if active_only:
                query = query.filter(ChatSession.is_active == True)  # noqa: E712 (SQL expression)
            sessions = query.order_by(ChatSession.last_active.desc()).all()
            if not active_only:
                return sessions

            active_sessions = []
            for session in sessions:
                if session.is_expired(self.session_timeout):
                    session.deactivate()
                    self._remove_from_cache(session.id)
                    self._remove_from_user_sessions(user_id, session.id)
                else:
                    active_sessions.append(session)
            return active_sessions
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error getting user sessions: %s", e)
            raise SessionManagerError(f"Failed to get user sessions: {e}") from e

    def cleanup_inactive_sessions(self) -> int:
        """
        Clean up inactive and expired sessions.

        Returns:
            int: Number of sessions cleaned up in the database (as reported by
            ``ChatSession.cleanup_expired_sessions``)

        Raises:
            SessionManagerError: If the database cleanup fails
        """
        try:
            cleaned_count = ChatSession.cleanup_expired_sessions(self.session_timeout)
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error during cleanup: %s", e)
            raise SessionManagerError(f"Failed to cleanup sessions: {e}") from e

        # Best-effort: never raises for Redis problems.
        self._cleanup_expired_cache_sessions()

        logger.info("Cleaned up %s expired sessions", cleaned_count)
        return cleaned_count

    def delete_session(self, session_id: str) -> None:
        """
        Delete a session completely (database row and Redis entries).

        Args:
            session_id: Session identifier

        Raises:
            SessionNotFoundError: If session doesn't exist
            SessionManagerError: If the database delete fails
        """
        try:
            session = db.session.query(ChatSession).filter(
                ChatSession.id == session_id
            ).first()
            if session is None:
                raise SessionNotFoundError(f"Session {session_id} not found")
            user_id = session.user_id
            # Delete from database (cascade will handle related records)
            db.session.delete(session)
            db.session.commit()
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error deleting session: %s", e)
            raise SessionManagerError(f"Failed to delete session: {e}") from e

        self._remove_from_cache(session_id)
        self._remove_from_user_sessions(user_id, session_id)
        logger.info("Deleted session %s", session_id)

    def set_session_language(self, session_id: str, language: str) -> None:
        """
        Set the programming language for a session.

        Args:
            session_id: Session identifier
            language: Programming language

        Raises:
            SessionNotFoundError: If session doesn't exist
            SessionExpiredError: If session has expired
            SessionManagerError: If the update fails
        """
        try:
            # Load from the DB (not the cache) so the change is applied to a
            # database-backed object.
            session = self._load_session_from_db(session_id)
            session.set_language(language)
            self._cache_session(session)
            logger.info("Set language to %s for session %s", language, session_id)
        except SessionManagerError:
            raise
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Error setting session language: %s", e)
            raise SessionManagerError(f"Failed to set session language: {e}") from e
        except Exception as e:
            logger.error("Error setting session language: %s", e)
            raise SessionManagerError(f"Failed to set session language: {e}") from e

    def increment_message_count(self, session_id: str) -> None:
        """
        Increment the message count for a session.

        Args:
            session_id: Session identifier

        Raises:
            SessionNotFoundError: If session doesn't exist
            SessionExpiredError: If session has expired
            SessionManagerError: If the update fails
        """
        try:
            # Load from the DB (not the cache) so the change is applied to a
            # database-backed object.
            session = self._load_session_from_db(session_id)
            session.increment_message_count()
            self._cache_session(session)
        except SessionManagerError:
            raise
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Error incrementing message count: %s", e)
            raise SessionManagerError(f"Failed to increment message count: {e}") from e
        except Exception as e:
            logger.error("Error incrementing message count: %s", e)
            raise SessionManagerError(f"Failed to increment message count: {e}") from e

    # ------------------------------------------------------------------
    # Database helpers
    # ------------------------------------------------------------------

    def _load_session_from_db(self, session_id: str) -> ChatSession:
        """Load an active, unexpired session from the database and re-cache it.

        Expired sessions are deactivated and removed from Redis before
        ``SessionExpiredError`` is raised.
        """
        try:
            session = db.session.query(ChatSession).filter(
                ChatSession.id == session_id,
                ChatSession.is_active == True  # noqa: E712 (SQL expression)
            ).first()
            if session is None:
                raise SessionNotFoundError(f"Session {session_id} not found")
            if session.is_expired(self.session_timeout):
                session.deactivate()
                self._remove_from_cache(session_id)
                self._remove_from_user_sessions(session.user_id, session_id)
                raise SessionExpiredError(f"Session {session_id} has expired")
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Database error getting session %s: %s", session_id, e)
            raise SessionManagerError(f"Failed to get session: {e}") from e

        self._cache_session(session)
        return session

    def _rollback(self) -> None:
        """Roll back the DB transaction after a failed database operation.

        SQLAlchemy requires an explicit rollback after a failed flush/commit
        before the session can be used again. A failing rollback is only
        logged so the caller still sees the original error.
        """
        try:
            db.session.rollback()
        except SQLAlchemyError as e:
            logger.warning("Database rollback failed: %s", e)

    def _expire_session(self, session_id: str) -> None:
        """Deactivate a session in the database and drop its Redis entries.

        Database errors are rolled back and logged, not raised. The cache
        entry is removed even if the database update fails.
        """
        try:
            session = db.session.query(ChatSession).filter(
                ChatSession.id == session_id
            ).first()
            if session:
                session.deactivate()
                self._remove_from_user_sessions(session.user_id, session_id)
        except SQLAlchemyError as e:
            self._rollback()
            logger.error("Error expiring session %s: %s", session_id, e)
        self._remove_from_cache(session_id)

    def _is_session_expired(self, session: ChatSession) -> bool:
        """Check if a session is expired under this manager's timeout."""
        return session.is_expired(self.session_timeout)

    # ------------------------------------------------------------------
    # Redis cache helpers (all best-effort)
    # ------------------------------------------------------------------

    def _cache_session(self, session: ChatSession) -> None:
        """Cache a session in Redis as JSON with a TTL slightly above the timeout.

        Redis failures and non-JSON-serializable metadata are logged and
        ignored; caching never breaks the calling operation.
        """
        if not self.redis_client:
            return  # Skip caching if Redis is not available
        try:
            cache_key = f"{self.cache_prefix}{session.id}"
            session_data = {
                'id': session.id,
                'user_id': session.user_id,
                'language': session.language,
                'created_at': session.created_at.isoformat(),
                'last_active': session.last_active.isoformat(),
                'message_count': session.message_count,
                'is_active': session.is_active,
                'session_metadata': session.session_metadata
            }
            self.redis_client.setex(
                cache_key,
                self.session_timeout + self.CACHE_TTL_BUFFER_SECONDS,
                json.dumps(session_data)
            )
        except (redis.RedisError, TypeError, ValueError) as e:
            # TypeError/ValueError: metadata that json.dumps cannot encode.
            logger.warning("Failed to cache session %s: %s", session.id, e)

    def _get_cached_session(self, session_id: str) -> Optional[ChatSession]:
        """Rebuild a session from the Redis cache, or return None on a miss.

        The returned object is constructed directly, not loaded through
        ``db.session``. Unreadable entries (bad JSON, missing keys, malformed
        timestamps) are evicted and treated as a miss.
        """
        if not self.redis_client:
            return None  # Skip cache lookup if Redis is not available
        try:
            cached_data = self.redis_client.get(f"{self.cache_prefix}{session_id}")
        except redis.RedisError as e:
            logger.warning("Failed to get cached session %s: %s", session_id, e)
            return None
        if not cached_data:
            return None

        try:
            session_data = json.loads(cached_data)
            session = ChatSession(
                user_id=session_data['user_id'],
                language=session_data['language'],
                session_metadata=session_data['session_metadata']
            )
            session.id = session_data['id']
            session.created_at = datetime.fromisoformat(session_data['created_at'])
            session.last_active = datetime.fromisoformat(session_data['last_active'])
            session.message_count = session_data['message_count']
            session.is_active = session_data['is_active']
            return session
        except (ValueError, TypeError, KeyError) as e:
            # ValueError covers json.JSONDecodeError and bad ISO timestamps.
            logger.warning("Failed to get cached session %s: %s", session_id, e)
            self._remove_from_cache(session_id)
            return None

    def _remove_from_cache(self, session_id: str) -> None:
        """Remove a session from Redis cache."""
        if not self.redis_client:
            return  # Skip cache removal if Redis is not available
        try:
            self.redis_client.delete(f"{self.cache_prefix}{session_id}")
        except redis.RedisError as e:
            logger.warning("Failed to remove session %s from cache: %s", session_id, e)

    def _add_to_user_sessions(self, user_id: str, session_id: str) -> None:
        """Add a session id to the user's Redis set and refresh the set's TTL."""
        if not self.redis_client:
            return  # Skip user session tracking if Redis is not available
        try:
            user_sessions_key = f"{self.user_sessions_prefix}{user_id}"
            self.redis_client.sadd(user_sessions_key, session_id)
            # Set expiration for user sessions list
            self.redis_client.expire(user_sessions_key, self.session_timeout * 2)
        except redis.RedisError as e:
            logger.warning("Failed to add session to user sessions list: %s", e)

    def _remove_from_user_sessions(self, user_id: str, session_id: str) -> None:
        """Remove a session id from the user's Redis set."""
        if not self.redis_client:
            return  # Skip user session tracking if Redis is not available
        try:
            self.redis_client.srem(f"{self.user_sessions_prefix}{user_id}", session_id)
        except redis.RedisError as e:
            logger.warning("Failed to remove session from user sessions list: %s", e)

    def _cleanup_expired_cache_sessions(self) -> None:
        """Delete expired or unreadable session entries from the Redis cache.

        Deleting a cache entry is always safe: the next lookup falls back to
        the database.
        """
        if not self.redis_client:
            return  # Skip cache cleanup if Redis is not available
        try:
            pattern = f"{self.cache_prefix}*"
            max_age = timedelta(seconds=self.session_timeout)
            # NOTE(review): assumes cached timestamps are naive UTC, matching
            # datetime.utcnow() here - confirm against ChatSession.
            now = datetime.utcnow()
            expired_keys = []
            # SCAN walks the keyspace incrementally instead of blocking Redis
            # the way KEYS does.
            for key in self.redis_client.scan_iter(match=pattern):
                try:
                    cached_data = self.redis_client.get(key)
                    if not cached_data:
                        continue
                    session_data = json.loads(cached_data)
                    last_active = datetime.fromisoformat(session_data['last_active'])
                    if now - last_active > max_age:
                        expired_keys.append(key)
                except (ValueError, TypeError, KeyError):
                    # Invalid data (or naive/aware mismatch), mark for deletion
                    expired_keys.append(key)

            if expired_keys:
                self.redis_client.delete(*expired_keys)
                logger.info("Cleaned up %s expired cache entries", len(expired_keys))
        except redis.RedisError as e:
            logger.warning("Failed to cleanup expired cache sessions: %s", e)
def create_session_manager(redis_client: redis.Redis, session_timeout: int = 3600) -> SessionManager:
    """Build a ``SessionManager`` from a Redis client and a timeout.

    Args:
        redis_client: Redis client instance used by the manager for caching
        session_timeout: Session timeout in seconds (default: 3600)

    Returns:
        SessionManager: A new manager configured with the given arguments
    """
    manager = SessionManager(redis_client=redis_client, session_timeout=session_timeout)
    return manager