"""JSON-based memory store for user feedback and session history."""

import json
import logging
import os
import re
from datetime import datetime
from typing import Dict, Any, List, Optional

try:
    import numpy as np
    NUMPY_AVAILABLE = True
except ImportError:
    NUMPY_AVAILABLE = False
    np = None


class MemoryStore:
    """JSON-file-backed store for session history, user vectors and intents.

    Three files are managed:

    * ``path`` -- session history, shaped ``{"sessions": [...]}``.
    * ``user_memory_path`` -- user profiles keyed by user id, holding
      ``positive``/``negative`` vector lists, a ``history`` list and optional
      ``preference_features``.
    * ``intent_memory_path`` -- per-user intent records (``u_llm`` vector,
      ``intent`` JSON, ``created_at`` timestamp).

    Every write serializes the whole document to a string first and then
    replaces the target file via a temporary file, so a serialization error
    or a process crash mid-write leaves the previous file intact instead of
    a half-written JSON document.
    """

    _logger = logging.getLogger(__name__)

    # Line-oriented, best-effort regexes used to strip ``"embedding": ...``
    # fields from a corrupted history file. The first removes a field together
    # with the value text up to the next comma/newline/closing brace; the
    # second removes a dangling key whose value was never written.
    _EMBEDDING_FIELD_RE = re.compile(r'\"embedding\":\s*[^,\n}]*,?\n?')
    _DANGLING_EMBEDDING_RE = re.compile(r'\"embedding\":\s*\n')

    # Number of table instances kept per shown pattern when recording a session.
    _TABLE_SAMPLE_SIZE = 5

    def __init__(self, path: str = "memory/history.json",
                 user_memory_path: str = "memory/user_memory.json",
                 intent_memory_path: str = "memory/intent_memory.json"):
        """Initialize the store, creating parent directories and empty files.

        Args:
            path: Path to the JSON file storing session history.
            user_memory_path: Path to the JSON file storing user vectors.
            intent_memory_path: Path to the JSON file storing intent vectors.
        """
        self.path = path
        self.user_memory_path = user_memory_path
        self.intent_memory_path = intent_memory_path

        initial_documents = (
            (path, {"sessions": []}),
            (user_memory_path, {}),
            (intent_memory_path, {}),
        )
        for file_path, empty_document in initial_documents:
            # os.path.dirname("history.json") == "" and os.makedirs("") raises,
            # so only create a directory when the path actually contains one.
            directory = os.path.dirname(file_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            # Create the initial structure if the file does not exist yet.
            if not os.path.exists(file_path):
                self._atomic_write_json(file_path, empty_document)

    @staticmethod
    def _atomic_write_json(path: str, data: Any) -> None:
        """Serialize ``data`` and atomically replace ``path`` with the result.

        Serialization happens before the file is touched, so a non-serializable
        value raises ``TypeError`` without damaging the existing file.

        Args:
            path: Destination file.
            data: JSON-serializable document.
        """
        payload = json.dumps(data, indent=2, ensure_ascii=False)
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(payload)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the temp file; the original error wins.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(self) -> Dict[str, Any]:
        """Load session history from the JSON file.

        If the file is not valid JSON, its raw text is backed up to
        ``<path>.corrupt`` and a repair is attempted by stripping
        ``embedding`` fields. A successful repair is written back.

        Returns:
            Dictionary containing session history. ``{"sessions": []}`` if the
            file is missing or cannot be repaired.
        """
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()
        except FileNotFoundError:
            self._logger.warning("History file %s not found, returning empty structure", self.path)
            return {"sessions": []}

        try:
            return json.loads(content)
        except json.JSONDecodeError as e:
            self._logger.warning("JSON file corrupted at %s, attempting to fix...", e.pos)
            return self._repair_history(content)

    def _repair_history(self, content: str) -> Dict[str, Any]:
        """Try to salvage a corrupted history document by dropping embeddings.

        Args:
            content: Raw text of the corrupted history file.

        Returns:
            The repaired document (also persisted via ``save``), or
            ``{"sessions": []}`` if it still cannot be parsed.
        """
        self._backup_corrupted_file(content)

        # Embedding fields may contain data that could not be serialized, so
        # strip them (complete ones first, then dangling keys) and retry.
        repaired = self._EMBEDDING_FIELD_RE.sub("", content)
        repaired = self._DANGLING_EMBEDDING_RE.sub("", repaired)

        try:
            data = json.loads(repaired)
        except json.JSONDecodeError:
            self._logger.error("Cannot fix JSON file, returning empty structure")
            return {"sessions": []}

        # Further cleanup: drop any embedding still present in shown patterns.
        sessions = data.get("sessions", []) if isinstance(data, dict) else []
        for session in sessions:
            if not isinstance(session, dict):
                continue
            for pattern in session.get("shown_patterns") or []:
                if isinstance(pattern, dict):
                    pattern.pop("embedding", None)

        self.save(data)
        self._logger.info("JSON file fixed and saved")
        return data

    def _backup_corrupted_file(self, content: str) -> None:
        """Best-effort copy of corrupted history text to ``<path>.corrupt``.

        The repair path may drop fields or fall back to an empty structure
        that the next ``save`` persists, so the raw text is kept for manual
        recovery. Failure to write the backup is logged, not raised.
        """
        backup_path = f"{self.path}.corrupt"
        try:
            with open(backup_path, "w", encoding="utf-8") as f:
                f.write(content)
            self._logger.warning("Backed up corrupted history file to %s", backup_path)
        except OSError as e:
            self._logger.error("Could not back up corrupted history file: %s", e)

    @staticmethod
    def _clean_key(key: Any) -> Any:
        """Convert a numpy scalar dict key to its Python equivalent."""
        if np is not None and isinstance(key, np.generic):
            return key.item()
        return key

    def _clean_for_json(self, obj: Any) -> Any:
        """Recursively convert an object into JSON-serializable Python types.

        numpy arrays become lists and numpy scalars (any ``np.generic``,
        including ``np.bool_``) become Python scalars. Dicts, lists and tuples
        are walked recursively; tuples become lists, which is how ``json``
        encodes them anyway. Other values are returned unchanged.

        Args:
            obj: Object to clean.

        Returns:
            JSON-serializable object.
        """
        # ``np`` is None when numpy is not installed (module-level import guard).
        if np is not None:
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.generic):
                return obj.item()

        if isinstance(obj, dict):
            return {self._clean_key(k): self._clean_for_json(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self._clean_for_json(item) for item in obj]
        return obj

    def save(self, data: Dict[str, Any]):
        """Save session history to the JSON file.

        Args:
            data: Dictionary containing session history.
        """
        self._atomic_write_json(self.path, self._clean_for_json(data))

    def load_user_memory(self) -> Dict[str, Any]:
        """Load user memory (vectors) from the JSON file.

        Returns:
            Dictionary containing user vectors, or an empty dict if the file
            does not exist.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
                It is deliberately not auto-repaired, so that a following
                save cannot silently wipe every stored profile.
        """
        try:
            with open(self.user_memory_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def save_user_memory(self, data: Dict[str, Any]):
        """Save user memory (vectors) to the JSON file.

        numpy arrays/scalars anywhere in ``data`` are converted first.

        Args:
            data: Dictionary containing user vectors.
        """
        self._atomic_write_json(self.user_memory_path, self._clean_for_json(data))

    @classmethod
    def _strip_heavy_fields(cls, item: Dict[str, Any], keep_sample: bool) -> Dict[str, Any]:
        """Return a shallow copy of a pattern dict without bulky fields.

        Drops ``embedding`` and ``table_instances``. When ``keep_sample`` is set
        and no ``table_instances_sample`` exists yet, the first few instances
        are kept under that key for display.
        """
        cleaned = item.copy()  # never mutate the caller's pattern
        cleaned.pop("embedding", None)
        if "table_instances" in cleaned:
            instances = cleaned.pop("table_instances")
            if keep_sample and "table_instances_sample" not in cleaned:
                cleaned["table_instances_sample"] = (
                    instances[:cls._TABLE_SAMPLE_SIZE] if instances else []
                )
        return cleaned

    def add_session(self, query: str, params: Dict[str, Any],
                    patterns: List[Dict[str, Any]], feedback: Dict[str, Any]):
        """Append a new session to the history file.

        Args:
            query: User's query.
            params: Extracted mining parameters.
            patterns: Patterns shown to the user; ``embedding`` is dropped and
                ``table_instances`` is reduced to ``table_instances_sample``.
            feedback: User feedback (like/dislike). Dict entries of list values
                lose ``embedding`` and ``table_instances``; other entries
                (e.g. lists of feature names) are stored as-is.
        """
        data = self.load()

        shown_patterns = [
            self._strip_heavy_fields(pattern, keep_sample=True) for pattern in patterns
        ]

        cleaned_feedback: Dict[str, Any] = {}
        for key, value in feedback.items():
            if isinstance(value, list):
                cleaned_feedback[key] = [
                    self._strip_heavy_fields(item, keep_sample=False)
                    if isinstance(item, dict) else item
                    for item in value
                ]
            else:
                cleaned_feedback[key] = value

        session = {
            "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "query": query,
            "params": params,
            "shown_patterns": shown_patterns,
            "feedback": cleaned_feedback
        }

        # setdefault tolerates a (repaired) document without a "sessions" key.
        data.setdefault("sessions", []).append(session)
        self.save(data)

    def get_all_sessions(self) -> List[Dict[str, Any]]:
        """Get all historical sessions.

        Returns:
            List of all sessions (empty if none are stored).
        """
        return self.load().get("sessions", [])

    @staticmethod
    def _ensure_profile(memory: Dict[str, Any], user_id: str) -> Dict[str, Any]:
        """Return ``memory[user_id]``, creating it and any missing list fields."""
        profile = memory.setdefault(user_id, {})
        for field in ("positive", "negative", "history"):
            profile.setdefault(field, [])
        return profile

    def _add_vector(self, user_id: str, polarity: str, vector: Any) -> None:
        """Append ``vector`` to the ``polarity`` list of a user's profile.

        ``save_user_memory`` converts numpy arrays/scalars to plain lists and
        numbers before writing.
        """
        memory = self.load_user_memory()
        self._ensure_profile(memory, user_id)[polarity].append(vector)
        self.save_user_memory(memory)

    # Annotations are strings so the class still defines when numpy is absent
    # (``np`` is None then and ``np.ndarray`` would raise AttributeError).
    def add_positive(self, user_id: str, vector: "np.ndarray"):
        """Add a positive (liked) pattern vector to user memory.

        Args:
            user_id: User identifier.
            vector: Pattern embedding vector (numpy array or sequence).
        """
        self._add_vector(user_id, "positive", vector)

    def add_negative(self, user_id: str, vector: "np.ndarray"):
        """Add a negative (disliked) pattern vector to user memory.

        Args:
            user_id: User identifier.
            vector: Pattern embedding vector (numpy array or sequence).
        """
        self._add_vector(user_id, "negative", vector)

    def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get a user profile (positive and negative vectors).

        Args:
            user_id: User identifier.

        Returns:
            The stored profile dict, or None if the user is not found.
        """
        return self.load_user_memory().get(user_id)

    def save_preference_features(self, user_id: str, liked: Optional[List[str]] = None,
                                 disliked: Optional[List[str]] = None):
        """Save user preference feature names for preference-weighted scoring.

        Args:
            user_id: User identifier.
            liked: Feature names the user likes (e.g. POI types).
            disliked: Feature names the user dislikes.
        """
        memory = self.load_user_memory()
        profile = self._ensure_profile(memory, user_id)
        profile["preference_features"] = {
            "like": list(liked) if liked else [],
            "dislike": list(disliked) if disliked else []
        }
        self.save_user_memory(memory)

    def get_preference_features(self, user_id: str) -> Optional[Dict[str, List[str]]]:
        """Get user preference feature names (for preference-weighted scoring).

        Args:
            user_id: User identifier.

        Returns:
            Dict with 'like' and 'dislike' lists of feature names, or None if
            the user or the features are not set.
        """
        profile = self.load_user_memory().get(user_id)
        if not profile:
            return None
        return profile.get("preference_features")

    def save_intent(self, user_id: str, u_llm: "np.ndarray", intent_json: Dict[str, Any]):
        """Save an intent vector and intent JSON to intent memory.

        Best-effort: any error is logged and swallowed so that intent
        persistence never breaks the caller.

        Args:
            user_id: User identifier.
            u_llm: LLM-generated initial user preference vector.
            intent_json: Intent JSON from IntentEncoder; numpy values inside
                are converted to Python types.
        """
        try:
            if os.path.exists(self.intent_memory_path):
                with open(self.intent_memory_path, "r", encoding="utf-8") as f:
                    intent_memory = json.load(f)
            else:
                intent_memory = {}

            intent_memory[user_id] = {
                "u_llm": self._clean_for_json(u_llm),
                "intent": self._clean_for_json(intent_json),
                "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }

            self._atomic_write_json(self.intent_memory_path, intent_memory)
            self._logger.info("Saved intent for user %s", user_id)
        except Exception as e:
            self._logger.error("Error saving intent: %s", e, exc_info=True)

    def load_intent(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Load the intent vector and intent JSON for a user.

        Args:
            user_id: User identifier.

        Returns:
            Dictionary with 'u_llm' (vector as list), 'intent' (JSON) and
            'created_at', or None if not found or on any read error (logged).
        """
        try:
            if not os.path.exists(self.intent_memory_path):
                return None
            with open(self.intent_memory_path, "r", encoding="utf-8") as f:
                intent_memory = json.load(f)
            return intent_memory.get(user_id)
        except Exception as e:
            self._logger.error("Error loading intent: %s", e, exc_info=True)
            return None

    def get_round_history(self, user_id: str) -> List[Dict[str, Any]]:
        """Get the iteration round history for a user (for Stage4).

        Args:
            user_id: User identifier.

        Returns:
            List of round history dictionaries (empty if none).
        """
        profile = self.load_user_memory().get(user_id)
        if profile is None:
            return []
        # NOTE(review): profiles created in this class carry a "history" list,
        # but this reads "rounds"; presumably "rounds" is written elsewhere --
        # confirm against the Stage4 writer.
        return profile.get("rounds", [])


