"""Interaction evaluator for preference learning evaluation."""

import logging
from typing import Dict, List, Any, Optional
import numpy as np

from experiment.simulator import UserSimulator
from experiment.metrics import accuracy, precision, recall, f1, compute_threshold
from experiment.sampler import DiversitySampler

logger = logging.getLogger(__name__)


class InteractionEvaluator:
    """Evaluator for interactive preference learning.

    Drives a simulated user (``UserSimulator``) through several feedback
    rounds against a pipeline manager and records, after every round, how
    well the manager's scores over ALL mined patterns agree with the
    simulator's ground-truth labels (accuracy, precision, recall, F1).
    """

    def __init__(
        self,
        manager,
        simulator: UserSimulator,
        top_k: int = 5,
        rounds: int = 10,
        sampling_strategy: str = "mixed"
    ):
        """Initialize interaction evaluator.

        Args:
            manager: PipelineManager instance
            simulator: UserSimulator instance
            top_k: Number of patterns to recommend each round
            rounds: Number of interaction rounds
            sampling_strategy: Pattern sampling strategy forwarded to
                ``DiversitySampler.sample_diverse_patterns``
                - "top": Select top-k highest scores (original)
                - "mixed": Mix of top, middle, and random (recommended)
                - "stratified": Stratified sampling by score ranges
                - "uncertainty": Select patterns near decision boundary
        """
        self.manager = manager
        self.simulator = simulator
        self.top_k = top_k
        self.rounds = rounds
        self.sampling_strategy = sampling_strategy
        self.history = []  # Accuracy after each round
        self.precision_history = []  # Precision after each round
        self.recall_history = []  # Recall after each round
        self.f1_history = []  # F1 after each round
        self.pre_interaction_metrics = None  # Metrics before any feedback iteration

        # Populated by run(): every mined pattern and the simulator's labels for them
        self.all_patterns = None
        self.ground_truth = None

        logger.info(f"InteractionEvaluator initialized: top_k={top_k}, rounds={rounds}, sampling_strategy={sampling_strategy}")

    @staticmethod
    def _pattern_items(pattern):
        """Return ``pattern['pattern']`` as a sequence of items.

        A comma-separated string is split on ``','``; any other value
        (normally a list) is returned unchanged. A missing key yields ``[]``.
        """
        items = pattern.get('pattern', [])
        if isinstance(items, str):
            items = items.split(',')
        return items

    def _with_embedding(self, pattern):
        """Return a shallow copy of ``pattern`` that carries an ``'embedding'``.

        If the pattern has no ``'embedding'`` key, its items are encoded with
        ``manager.embedder.encode_pattern``. If the encoder returns None, the copy
        is returned without an embedding.
        """
        pattern_copy = pattern.copy()
        if 'embedding' not in pattern_copy:
            vector = self.manager.embedder.encode_pattern(self._pattern_items(pattern))
            if vector is not None:
                pattern_copy['embedding'] = vector
        return pattern_copy

    def _score_patterns_with_manager(self, patterns, user_id, interaction_round):
        """Score ``patterns`` for ``user_id`` using the manager's current state.

        User vector resolution, in order:
        1. ``manager.learner.get_user_vector`` (the fused vector, which reflects
           all feedback received so far).
        2. ``build_user_vector`` over ``manager.memory.get_user_profile``.
        3. If there is no profile either, scoring is delegated entirely to
           ``manager.learner.score_all_patterns``.

        The trained contrastive model (``manager.trainer.model``) is used only
        when ``manager.stage3_use_contrastive`` is set and the trainer holds a
        model. Otherwise baseline scoring (``model=None``) is used.

        Args:
            patterns: Pattern dicts. Each may carry a precomputed
                ``'embedding'``; otherwise its ``'pattern'`` items are encoded.
            user_id: Identifier of the user being evaluated.
            interaction_round: Round index forwarded to the learner.

        Returns:
            The result of ``score_patterns`` or ``learner.score_all_patterns``.
            This is presumably one score per pattern in input order (confirm
            against ``preference.scorer``), and may be None.
        """
        from preference.scorer import score_patterns, build_user_vector

        # Use the FUSED user vector (LLM intent + feedback) so scores reflect
        # every round of feedback processed so far.
        user_vec = self.manager.learner.get_user_vector(user_id, interaction_round)

        if user_vec is not None:
            logger.debug(f"Using fused user vector (dim={len(user_vec)}, norm={np.linalg.norm(user_vec):.4f})")
        else:
            user_profile = self.manager.memory.get_user_profile(user_id)
            if not user_profile:
                # Neither a fused vector nor a stored profile: let the learner score directly.
                return self.manager.learner.score_all_patterns(
                    patterns,
                    user_id=user_id,
                    interaction_round=interaction_round
                )
            user_vec = build_user_vector(
                user_profile,
                dim=self.manager.embedder.model.get_sentence_embedding_dimension()
            )
            logger.debug(f"Built user vector from profile (dim={len(user_vec)}, norm={np.linalg.norm(user_vec):.4f})")

        # trainer.model is the one updated after training, so check it rather than a cached copy.
        use_contrastive = (self.manager.stage3_use_contrastive and
                           self.manager.trainer is not None and
                           self.manager.trainer.model is not None)

        patterns_with_embeddings = [self._with_embedding(pattern) for pattern in patterns]

        if use_contrastive:
            logger.info(f"Scoring {len(patterns)} patterns with contrastive learning model")
            model = self.manager.trainer.model
        else:
            logger.info(f"Scoring {len(patterns)} patterns with baseline method (no contrastive model)")
            model = None  # No model = baseline

        return score_patterns(
            patterns_with_embeddings,
            user_vec,
            model=model,
            alpha=self.manager.stage3_alpha
        )

    def _metrics_for_scores(self, scores):
        """Threshold ``scores`` and compare the result with ``self.ground_truth``.

        A pattern index is predicted positive when its score is
        ``>= compute_threshold(scores)``.

        Returns:
            Tuple ``(metrics, predictions)``:
            - ``metrics`` has the keys ``'accuracy'``, ``'precision'``,
              ``'recall'``, ``'f1'`` and ``'threshold'``.
            - ``predictions`` maps each pattern index to 0 or 1.
        """
        threshold = compute_threshold(scores)
        predictions = {
            idx: (1 if score >= threshold else 0)
            for idx, score in enumerate(scores)
        }
        metrics = {
            'accuracy': accuracy(self.ground_truth, predictions),
            'precision': precision(self.ground_truth, predictions),
            'recall': recall(self.ground_truth, predictions),
            'f1': f1(self.ground_truth, predictions),
            'threshold': threshold
        }
        return metrics, predictions

    def _candidate_patterns(self, query, source):
        """Return ``(ranked_patterns, similarity_scores)`` for one round.

        Args:
            query: Query string used by the ``process_query`` fallback.
            source: Result dict of ``start_iteration`` / ``next_iteration``, or
                a falsy value. A falsy value triggers the
                ``manager.process_query`` fallback, which provides no
                similarity scores.
        """
        if source:
            return source.get('patterns', []), source.get('similarity_scores', {})
        result = self.manager.process_query(query, iteration_rounds=None)
        return result.get('patterns', []), {}

    def _recommend(self, ranked_patterns, similarity_scores):
        """Pick the patterns to show the simulated user this round.

        Each candidate's score is looked up in ``similarity_scores`` under the
        key formed by joining its sorted items with ``','``. If the key is
        absent, the candidate's ``'confidence'`` is used, or 0.0 if that is
        missing too. With at least ``top_k`` candidates, ``DiversitySampler`` selects
        ``top_k`` of them; otherwise all candidates are returned.
        """
        candidate_scores = [
            similarity_scores.get(
                ','.join(sorted(self._pattern_items(pattern))),
                pattern.get('confidence', 0.0)
            )
            for pattern in ranked_patterns
        ]

        if len(ranked_patterns) < self.top_k:
            logger.info(f"Using all {len(ranked_patterns)} available patterns (less than top_k={self.top_k})")
            return ranked_patterns

        recommended = DiversitySampler.sample_diverse_patterns(
            ranked_patterns,
            candidate_scores,
            top_k=self.top_k,
            strategy=self.sampling_strategy
        )
        logger.info(f"Selected {len(recommended)} patterns using {self.sampling_strategy} sampling strategy")
        return recommended

    def _simulate_feedback(self, recommended_patterns):
        """Label each recommended pattern with the simulator.

        Returns:
            ``{'positive': [...], 'negative': [...]}`` holding pattern strings.
            Patterns whose label is 1 go to ``'positive'``; every other label
            goes to ``'negative'``.
        """
        feedback = {
            'positive': [],
            'negative': []
        }
        for pattern in recommended_patterns:
            label = self.simulator.label_pattern(pattern)
            raw = pattern.get('pattern')
            pattern_str = ','.join(raw) if isinstance(raw, list) else str(pattern.get('pattern', ''))
            feedback['positive' if label == 1 else 'negative'].append(pattern_str)

        logger.info(f"Simulated feedback: {len(feedback['positive'])} positive, {len(feedback['negative'])} negative")
        return feedback

    def _update_model(self, session_id, feedback, is_last_round):
        """Send ``feedback`` to the manager via ``next_iteration``.

        Returns:
            Tuple ``(session_id, next_result)``:
            - With no active session, nothing is called and
              ``(session_id, None)`` is returned.
            - On a non-final round, ``next_result`` is the call's result. If
              the call fails, both values are None, which ends the session.
            - On the final round, ``next_result`` is always None and a failure
              is only logged.
        """
        if not session_id:
            return session_id, None

        if not is_last_round:
            try:
                next_result = self.manager.next_iteration(session_id, feedback)
                logger.info(f"Model updated via next_iteration, got {len(next_result.get('patterns', []))} patterns for next round")
                return session_id, next_result
            except Exception as e:
                logger.warning(f"Failed to update via next_iteration: {e}")
                # If next_iteration fails, we can't continue with session
                return None, None

        # Last round: update the model but don't request another round of patterns.
        try:
            self.manager.next_iteration(session_id, feedback)
            logger.info("Model updated via next_iteration (last round)")
        except Exception as e:
            logger.warning(f"Failed to update via next_iteration on last round: {e}")
        return session_id, None

    def run(self) -> Dict[str, Any]:
        """Run the evaluation loop.

        1. Mine all patterns once and cache them in ``self.all_patterns``.
        2. Build ground-truth labels with the simulator.
        3. Score all patterns once before any feedback (pre-interaction metrics).
        4. For each round: recommend patterns, simulate feedback, update the
           model, re-score ALL patterns, threshold the scores, compute
           metrics and append them to the history lists.

        Note that the history lists live on the instance and are appended to,
        so calling ``run`` twice accumulates results from both runs.

        Returns:
            Dict with per-round lists under ``'accuracy'``, ``'precision'``,
            ``'recall'`` and ``'f1'``. The pre-interaction metrics dict is
            under ``'pre_interaction'``.
        """
        logger.info("Starting evaluation loop...")

        # Step 1: Mine all patterns once and cache
        logger.info("Step 1: Mining all patterns...")
        self.all_patterns = self.manager.miner.mine_patterns(
            min_participation=0.6,
            max_pattern_size=5,
            priority="confidence"
        )
        logger.info(f"Mined {len(self.all_patterns)} patterns")

        # Step 2: Build ground truth
        logger.info("Step 2: Building ground truth...")
        self.ground_truth = self.simulator.build_ground_truth(self.all_patterns)
        logger.info(f"Ground truth: {sum(self.ground_truth.values())}/{len(self.ground_truth)} patterns liked")

        # Step 3: Iteration loop
        user_id = self.manager.stage3_user_id
        # Provide preference features for the preference-weighted (adaptive fusion) method
        self.manager.memory.save_preference_features(
            user_id,
            list(self.simulator.liked),
            list(self.simulator.disliked) if self.simulator.disliked else []
        )
        query = "I want to find good co-location patterns"  # Generic query for evaluation

        session_result = None
        try:
            session_result = self.manager.start_iteration(query, iteration_rounds=self.rounds)
            session_id = session_result['session_id']
            logger.info(f"Initialized iteration session: {session_id}")
        except Exception as e:
            logger.warning(f"Failed to start iteration, using process_query instead: {e}")
            session_id = None

        # Step 2.5: Evaluate once before any feedback iteration (LLM-only if available)
        pre_scores = self._score_patterns_with_manager(
            self.all_patterns,
            user_id=user_id,
            interaction_round=0
        )
        if pre_scores is None:
            logger.warning("No pre-interaction scores available, using zero baseline")
            pre_scores = [0.0] * len(self.all_patterns)

        self.pre_interaction_metrics, _ = self._metrics_for_scores(pre_scores)
        logger.info(
            "Pre-interaction metrics (round 0, no feedback) - "
            f"Accuracy: {self.pre_interaction_metrics['accuracy']:.4f}, "
            f"Precision: {self.pre_interaction_metrics['precision']:.4f}, "
            f"Recall: {self.pre_interaction_metrics['recall']:.4f}, "
            f"F1: {self.pre_interaction_metrics['f1']:.4f} "
            f"(threshold: {self.pre_interaction_metrics['threshold']:.4f})"
        )

        # Result of the previous round's next_iteration call (None until one succeeds).
        next_result = None

        for t in range(self.rounds):
            round_no = t + 1
            logger.info(f"=== Round {round_no}/{self.rounds} ===")

            # 3.1 Recommend patterns. Round 1 uses start_iteration's patterns and
            # later rounds use the previous next_iteration result. Without an
            # active session, fall back to process_query.
            if session_id:
                source = session_result if t == 0 else next_result
            else:
                source = None
            ranked_patterns, similarity_scores = self._candidate_patterns(query, source)

            if not ranked_patterns:
                logger.warning(f"No patterns available in round {round_no}, using top patterns from all_patterns")
                ranked_patterns = self.all_patterns[:20]  # Use top 20 for sampling
                similarity_scores = {}

            recommended_patterns = self._recommend(ranked_patterns, similarity_scores)

            # 3.2 Simulate feedback
            feedback = self._simulate_feedback(recommended_patterns)

            # 3.3 Update model BEFORE scoring so this round is scored with the updated model
            session_id, next_result = self._update_model(
                session_id, feedback, is_last_round=(t == self.rounds - 1)
            )

            # 3.4 Predict ALL patterns with the UPDATED model; round_no reflects
            # that round_no rounds of feedback have now been processed.
            if self.manager.trainer is not None and self.manager.trainer.model is not None:
                logger.debug(f"Round {round_no}: Trainer model available, using it for scoring")

            round_scores = self._score_patterns_with_manager(
                self.all_patterns,
                user_id=user_id,
                interaction_round=round_no
            )
            if round_scores is None:
                logger.warning(f"No scores available in round {round_no}, using zero baseline")
                round_scores = [0.0] * len(self.all_patterns)

            # len() rather than truthiness: the scorer may return a NumPy array,
            # whose truth value is ambiguous when it has more than one element.
            if len(round_scores) > 0:
                score_array = np.asarray(round_scores)
                logger.info(f"Round {round_no} score stats: min={score_array.min():.4f}, max={score_array.max():.4f}, "
                            f"mean={score_array.mean():.4f}, median={np.median(score_array):.4f}, "
                            f"std={score_array.std():.4f}")

            # 3.5 Convert to labels using a dynamic threshold and 3.6 compute metrics
            metrics, predictions = self._metrics_for_scores(round_scores)
            threshold = metrics['threshold']

            positive_predictions = sum(1 for v in predictions.values() if v == 1)
            logger.info(f"Round {round_no} predictions: {positive_predictions}/{len(predictions)} predicted as positive (threshold: {threshold:.4f})")

            self.history.append(metrics['accuracy'])
            self.precision_history.append(metrics['precision'])
            self.recall_history.append(metrics['recall'])
            self.f1_history.append(metrics['f1'])

            logger.info(f"Round {round_no} metrics - Accuracy: {metrics['accuracy']:.4f}, Precision: {metrics['precision']:.4f}, "
                        f"Recall: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f} (threshold: {threshold:.4f})")

        # Guard against rounds == 0, where the history lists are empty.
        if self.history:
            logger.info(f"Evaluation complete. Final metrics - Accuracy: {self.history[-1]:.4f}, "
                        f"Precision: {self.precision_history[-1]:.4f}, "
                        f"Recall: {self.recall_history[-1]:.4f}, "
                        f"F1: {self.f1_history[-1]:.4f}")
        else:
            logger.info("Evaluation complete. No interaction rounds were run.")

        return {
            'accuracy': self.history,
            'precision': self.precision_history,
            'recall': self.recall_history,
            'f1': self.f1_history,
            'pre_interaction': self.pre_interaction_metrics
        }

