"""User preference learner using average vector method with positive and negative feedback."""

import numpy as np
import math
from typing import List, Optional, Dict, Any, Tuple
import logging

from memory.store import MemoryStore
from learning.embedder import PatternEmbedder

# Module-level logger named after this module; PreferenceLearner emits its debug/info/warning messages here.
logger = logging.getLogger(__name__)


class PreferenceLearner:
    """User preference learner using average vector method with positive and negative feedback.

    Two scoring strategies are available through ``score_patterns``:

    * Per-user fused vector: ``u_t = α_t * u_llm + (1 - α_t) * u_feedback`` with
      ``α_t = e^(-λt)``. ``u_llm`` is the LLM intent vector loaded via
      ``memory.load_intent`` and ``u_feedback`` is ``mean(positive) - mean(negative)``
      taken from ``memory.get_user_profile``.
    * Global fallback: ``cos(p, positive) - cos(p, negative)``, where the two
      reference vectors are the mean embeddings of all liked / disliked patterns
      found in the ``sessions`` returned by ``memory.load()``.
    """

    def __init__(self, decay_lambda: float = 0.05, use_preference_weighted: bool = False):
        """Initialize preference learner.

        Args:
            decay_lambda: Decay rate λ for the intent-vector weight α_t = e^(-λt).
            use_preference_weighted: If True, adjust every score by a feature-level
                preference weight (see ``_apply_preference_weights``).
        """
        self.memory = MemoryStore()
        self.embedder = PatternEmbedder()
        self.decay_lambda = decay_lambda
        self.use_preference_weighted = use_preference_weighted

    # ------------------------------------------------------------------
    # Global (session-based) preference vectors
    # ------------------------------------------------------------------

    def build_user_vectors(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Build global preference vectors from all stored session feedback.

        Each session's ``feedback["like"]`` / ``feedback["dislike"]`` entry may be a
        dict with a ``"pattern"`` key or a bare pattern list; other entry types and
        empty patterns are ignored. Missing or ``None`` containers (store data,
        ``sessions``, a session, ``feedback``, the like/dislike lists) count as empty.

        Returns:
            Tuple ``(positive_vector, negative_vector)``:
            - positive_vector: mean embedding of liked patterns, or None
            - negative_vector: mean embedding of disliked patterns, or None
            ``(None, None)`` when there is no usable feedback at all.
        """
        data = self.memory.load() or {}

        liked_patterns: List[Any] = []
        disliked_patterns: List[Any] = []
        for session in data.get("sessions") or []:
            if not isinstance(session, dict):
                continue
            feedback = session.get("feedback")
            if not isinstance(feedback, dict):
                continue
            liked_patterns.extend(self._extract_patterns(feedback.get("like")))
            disliked_patterns.extend(self._extract_patterns(feedback.get("dislike")))

        positive_vector = self._mean_embedding(liked_patterns)
        if positive_vector is not None:
            logger.debug("Built positive vector from %d liked patterns", len(liked_patterns))

        negative_vector = self._mean_embedding(disliked_patterns)
        if negative_vector is not None:
            logger.debug("Built negative vector from %d disliked patterns", len(disliked_patterns))

        return positive_vector, negative_vector

    @staticmethod
    def _extract_patterns(entries: Any) -> List[Any]:
        """Return the non-empty patterns from one like/dislike feedback list (None → [])."""
        patterns: List[Any] = []
        for entry in entries or []:
            if isinstance(entry, dict):
                pattern = entry.get("pattern", [])
            elif isinstance(entry, list):
                pattern = entry
            else:
                continue  # unrecognised entry shape
            if pattern:
                patterns.append(pattern)
        return patterns

    def _mean_embedding(self, patterns: List[Any]) -> Optional[np.ndarray]:
        """Encode ``patterns`` and return their mean vector, or None if nothing was encoded."""
        if not patterns:
            return None
        vectors = self.embedder.encode_patterns(patterns)
        if vectors is None or len(vectors) == 0:
            return None
        return np.mean(vectors, axis=0)

    def build_user_vector(self) -> Optional[np.ndarray]:
        """Build user preference vector from historical feedback (backward compatibility).

        Returns:
            The positive vector from ``build_user_vectors`` (mean of liked
            patterns), or None if there is none.
        """
        positive_vec, _ = self.build_user_vectors()
        return positive_vec

    # ------------------------------------------------------------------
    # Per-user vectors (LLM intent + profile feedback)
    # ------------------------------------------------------------------

    @staticmethod
    def _mean_vector(vectors: Any) -> Optional[np.ndarray]:
        """Return the element-wise mean of stored vectors, or None if there are none."""
        if vectors is None or len(vectors) == 0:
            return None
        return np.mean([np.asarray(v, dtype=float) for v in vectors], axis=0)

    def compute_feedback_vector(self, user_id: str) -> Optional[np.ndarray]:
        """Compute user preference vector from feedback (μ⁺ - μ⁻).

        Args:
            user_id: User identifier.

        Returns:
            ``mean(profile["positive"]) - mean(profile["negative"])``, or None when
            the profile is missing, both lists are empty, or the difference is the
            zero vector. A missing side is treated as a zero vector whose dimension
            is taken from the other side.
        """
        user_profile = self.memory.get_user_profile(user_id)
        if not user_profile:
            return None

        mean_pos = self._mean_vector(user_profile.get("positive"))
        mean_neg = self._mean_vector(user_profile.get("negative"))
        if mean_pos is None and mean_neg is None:
            return None

        # Infer the dimension from the stored data instead of asking the embedder.
        if mean_pos is None:
            mean_pos = np.zeros_like(mean_neg)
        if mean_neg is None:
            mean_neg = np.zeros_like(mean_pos)

        user_vector = mean_pos - mean_neg
        # Identical positive/negative means carry no preference signal.
        if np.linalg.norm(user_vector) == 0:
            return None
        return user_vector

    def _load_intent_vector(self, user_id: str) -> Optional[np.ndarray]:
        """Return the stored LLM intent vector (``u_llm``) for ``user_id``, or None if absent/empty.

        Uses explicit ``None``/size checks so a vector stored as a numpy array does
        not hit numpy's ambiguous-truth-value error.
        """
        intent_data = self.memory.load_intent(user_id)
        if not intent_data:
            return None
        raw = intent_data.get("u_llm")
        if raw is None:
            return None
        vec = np.asarray(raw, dtype=float)
        if vec.size == 0:
            return None
        return vec

    def get_user_vector(self, user_id: str, t: int = 0) -> Optional[np.ndarray]:
        """Get fused user preference vector combining LLM intent and feedback.

        Formula: u_t = α_t * u_llm + (1-α_t) * u_feedback, where α_t = e^(-λt).

        Args:
            user_id: User identifier.
            t: Interaction round number (0 for the first interaction). Negative
                values are treated as 0 so α_t never exceeds 1.

        Returns:
            The fused vector when both parts exist, whichever part exists when only
            one does, or None when neither intent nor feedback is available.
        """
        u_llm = self._load_intent_vector(user_id)
        u_feedback = self.compute_feedback_vector(user_id)

        if u_llm is None:
            return u_feedback  # may itself be None
        if u_feedback is None:
            return u_llm

        alpha_t = math.exp(-self.decay_lambda * max(t, 0))
        u_t = self.fuse_vector(u_llm, u_feedback, alpha_t)
        logger.debug("Fused user vector: alpha_t=%.4f, t=%s", alpha_t, t)
        return u_t

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _cosine_scores(vectors: Any, ref: np.ndarray) -> List[float]:
        """Cosine similarity of each vector to ``ref`` (0.0 where either norm is zero)."""
        ref_norm = np.linalg.norm(ref)  # loop-invariant
        scores: List[float] = []
        for v in vectors:
            v_norm = np.linalg.norm(v)
            if v_norm > 0 and ref_norm > 0:
                scores.append(float(np.dot(v, ref) / (v_norm * ref_norm)))
            else:
                scores.append(0.0)
        return scores

    @staticmethod
    def _map_to_original(pattern_lists: List[Any], valid_scores: List[float]) -> List[float]:
        """Expand scores of the non-empty patterns back to input order; empty patterns get 0.0."""
        final_scores: List[float] = []
        valid_idx = 0
        for p_list in pattern_lists:
            if p_list:
                final_scores.append(valid_scores[valid_idx])
                valid_idx += 1
            else:
                final_scores.append(0.0)
        return final_scores

    def score_patterns(self, patterns: List[Dict[str, Any]], user_id: Optional[str] = None,
                       interaction_round: int = 0) -> Optional[List[float]]:
        """Score patterns against the user's preference vector(s).

        If ``user_id`` is given and ``get_user_vector`` yields a vector (LLM intent
        and/or per-user feedback), each score is the cosine similarity to it.
        Otherwise the global fallback is used:
        score = cos(pattern, positive) - cos(pattern, negative),
        with the vectors from ``build_user_vectors``. Patterns whose ``"pattern"``
        field is missing or empty score 0.0. Non-trivial results are passed through
        ``_apply_preference_weights``.

        Args:
            patterns: Pattern dictionaries, each containing a 'pattern' field.
            user_id: Optional user identifier enabling the fused per-user vector.
            interaction_round: Current interaction round (for the α_t decay).

        Returns:
            One score per input pattern (same order), or None if no preference
            vectors are available or the fallback path could not encode the patterns.
        """
        if user_id:
            fused_vector = self.get_user_vector(user_id, interaction_round)
            if fused_vector is not None:
                raw = self._score_with_vector(patterns, fused_vector)
                return self._apply_preference_weights(patterns, raw, user_id)

        # Fallback: global positive/negative vectors built from all sessions.
        positive_vec, negative_vec = self.build_user_vectors()
        if positive_vec is None and negative_vec is None:
            return None

        pattern_lists = [p.get("pattern", []) for p in patterns]
        valid_pattern_lists = [p_list for p_list in pattern_lists if p_list]
        if not valid_pattern_lists:
            logger.info("No valid patterns to score.")
            return [0.0] * len(patterns)

        p_vecs = self.embedder.encode_patterns(valid_pattern_lists)
        if p_vecs is None:
            logger.warning("Could not encode patterns for scoring.")
            return None

        count = len(p_vecs)
        pos_sims = (self._cosine_scores(p_vecs, positive_vec)
                    if positive_vec is not None else [0.0] * count)
        neg_sims = (self._cosine_scores(p_vecs, negative_vec)
                    if negative_vec is not None else [0.0] * count)
        # Similarity to liked patterns adds, similarity to disliked patterns subtracts.
        scores = [0.0 + pos - neg for pos, neg in zip(pos_sims, neg_sims)]

        final_scores = self._map_to_original(pattern_lists, scores)
        return self._apply_preference_weights(patterns, final_scores, user_id)

    @staticmethod
    def _pattern_features(raw: Any) -> List[str]:
        """Normalize a pattern (comma-separated string, iterable, or None) to stripped feature strings."""
        if not raw:
            return []
        if isinstance(raw, str):
            return [x.strip() for x in raw.split(",") if x.strip()]
        return [str(x).strip() for x in raw]

    def _apply_preference_weights(self, patterns: List[Dict[str, Any]], scores: List[float],
                                  user_id: Optional[str]) -> List[float]:
        """Adjust scores by feature-level like/dislike preferences (adaptive fusion).

        Active only when ``use_preference_weighted`` is True, ``user_id`` is set,
        ``scores`` aligns one-to-one with ``patterns`` and the memory store returns
        preference features; otherwise ``scores`` is returned unchanged.

        Per pattern: weight = 1 + (n_liked - n_disliked) / max(|pattern|, 1) * 0.5,
        clamped to [0.5, 1.5], where n_liked / n_disliked count the pattern's
        features in the user's ``like`` / ``dislike`` feature lists.
        Non-negative scores are multiplied by the weight. Negative scores are
        multiplied by ``2 - weight`` so that liked features still move the score
        up (toward zero) and disliked features move it down; multiplying a negative
        score by a weight > 1 would otherwise penalise liked features.
        """
        if not self.use_preference_weighted or not user_id or len(scores) != len(patterns):
            return scores
        pf = self.memory.get_preference_features(user_id)
        if not pf:
            return scores
        liked_set = set(pf.get("like") or [])
        disliked_set = set(pf.get("dislike") or [])

        out: List[float] = []
        for pattern, score in zip(patterns, scores):
            features = self._pattern_features(pattern.get("pattern"))
            n = max(len(features), 1)
            n_liked = sum(1 for f in features if f in liked_set)
            n_disliked = sum(1 for f in features if f in disliked_set)
            w = max(0.5, min(1.5, 1.0 + (n_liked - n_disliked) / n * 0.5))
            out.append(score * w if score >= 0 else score * (2.0 - w))
        return out

    def _score_with_vector(self, patterns: List[Dict[str, Any]], user_vector: np.ndarray) -> List[float]:
        """Score patterns by cosine similarity to a single user vector.

        Args:
            patterns: Pattern dictionaries, each containing a 'pattern' field.
            user_vector: User preference vector.

        Returns:
            One score per input pattern; empty patterns, zero-norm vectors, or a
            failed encoding all yield 0.0.
        """
        pattern_lists = [p.get("pattern", []) for p in patterns]
        valid_pattern_lists = [p_list for p_list in pattern_lists if p_list]
        if not valid_pattern_lists:
            return [0.0] * len(patterns)

        p_vecs = self.embedder.encode_patterns(valid_pattern_lists)
        if p_vecs is None:
            return [0.0] * len(patterns)

        scores = self._cosine_scores(p_vecs, user_vector)
        return self._map_to_original(pattern_lists, scores)

    # ------------------------------------------------------------------
    # Iterative update helpers
    # ------------------------------------------------------------------

    def fuse_vector(self, llm_vec: np.ndarray, fb_vec: np.ndarray, alpha: float) -> np.ndarray:
        """Fuse LLM intent vector and feedback vector.

        Formula: u_t = α * u_llm + (1-α) * u_feedback

        Args:
            llm_vec: LLM intent vector.
            fb_vec: Feedback-based vector.
            alpha: Fusion weight (0-1) given to the LLM intent vector.

        Returns:
            Fused user preference vector.
        """
        return alpha * llm_vec + (1 - alpha) * fb_vec

    def update_user_vector_iterative(self, user_id: str, alpha: float = 0.6) -> Optional[np.ndarray]:
        """Update user vector iteratively (for Stage4).

        Recomputes the feedback vector and fuses it with the stored LLM intent
        vector using a fixed ``alpha`` (no round-based decay).

        Args:
            user_id: User identifier.
            alpha: Fusion weight for LLM intent vs feedback.

        Returns:
            The fused vector, whichever single vector is available, or None.
        """
        u_feedback = self.compute_feedback_vector(user_id)
        u_llm = self._load_intent_vector(user_id)

        if u_llm is not None and u_feedback is not None:
            return self.fuse_vector(u_llm, u_feedback, alpha)
        return u_feedback if u_feedback is not None else u_llm

    def score_all_patterns(self, patterns: List[Dict[str, Any]], user_id: Optional[str] = None,
                           interaction_round: int = 0) -> Optional[List[float]]:
        """Score all patterns for evaluation (thin wrapper around ``score_patterns``).

        Args:
            patterns: Pattern dictionaries.
            user_id: Optional user identifier.
            interaction_round: Current interaction round number.

        Returns:
            List of similarity scores, or None if no user vectors are available.
        """
        return self.score_patterns(patterns, user_id=user_id, interaction_round=interaction_round)


