"""Iteration manager for Stage4 iterative interaction engine."""

import logging
import numpy as np
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime

logger = logging.getLogger(__name__)


class IterationState:
    """Mutable record of one user's progress through the Stage4 loop.

    ``user_id``, ``query`` and ``max_rounds`` are fixed at construction.
    ``current_round`` starts at 0 and is advanced by the caller,
    ``user_vector`` holds the current preference representation (``None``
    until assigned) and ``history`` collects one dict per recorded round.
    """

    def __init__(self, user_id: str, query: str, max_rounds: int):
        """Create a fresh state with no rounds run yet.

        Args:
            user_id: User identifier
            query: Original user query
            max_rounds: Maximum number of iteration rounds
        """
        # Inputs describing the session.
        self.user_id = user_id
        self.query = query
        self.max_rounds = max_rounds

        # Progress tracking.
        self.current_round = 0
        self.user_vector = None  # current preference representation
        self.history = []  # one entry per add_round() call

    def add_round(self, round_data: Dict[str, Any]):
        """Append a history entry for the current round.

        The entry starts with ``round`` (the value of ``current_round``) and a
        local, timezone-naive ISO-8601 ``timestamp``. The keys of
        ``round_data`` are merged on top, so they override those two on
        collision.

        Args:
            round_data: Extra per-round fields to record
        """
        entry = {
            "round": self.current_round,
            "timestamp": datetime.now().isoformat(),
        }
        entry.update(round_data)
        self.history.append(entry)

    def is_complete(self) -> bool:
        """Report whether the configured number of rounds has been reached.

        Returns:
            True once ``current_round`` is at least ``max_rounds``
        """
        return self.max_rounds <= self.current_round


class IterationManager:
    """Manager for iterative preference refinement (Stage4).

    Each round mines candidate patterns with ``miner``, embeds the top
    candidates with ``embedder``, ranks them against the current user
    preference vector (via ``learner``) and derives rules from the ranked
    list. Feedback is persisted in ``memory`` and fused with the stored
    LLM intent vector to produce the next user vector.
    """

    # Number of mined patterns (taken from the head of the miner's output)
    # that are embedded, scored and re-ranked in every round.
    TOP_K_CANDIDATES = 20

    def __init__(self, miner, learner, memory, embedder, llm_client=None,
                 intent_encoder=None, intent_mapper=None,
                 fusion_alpha: float = 0.6):
        """Initialize iteration manager.

        Args:
            miner: CoLocationMiner instance
            learner: PreferenceLearner instance
            memory: MemoryStore instance
            embedder: PatternEmbedder instance
            llm_client: Optional LLMClient for intent understanding
            intent_encoder: Optional IntentEncoder for Stage0
            intent_mapper: Optional IntentMapper for Stage0
            fusion_alpha: Weight for LLM intent vs feedback fusion (α)
        """
        self.miner = miner
        self.learner = learner
        self.memory = memory
        self.embedder = embedder
        self.llm_client = llm_client
        self.intent_encoder = intent_encoder
        self.intent_mapper = intent_mapper
        self.fusion_alpha = fusion_alpha

    # ------------------------------------------------------------------
    # Private helpers shared by the round / final-ranking code paths
    # ------------------------------------------------------------------

    @staticmethod
    def _default_mining_params() -> Dict[str, Any]:
        """Return a fresh dict of the default mining parameters."""
        return {
            'min_participation': 0.6,
            'max_pattern_size': 5,
            'priority': 'confidence'
        }

    @staticmethod
    def _vector_norm(vector) -> float:
        """Return the L2 norm of ``vector`` as a float, or 0.0 for None."""
        return float(np.linalg.norm(vector)) if vector is not None else 0.0

    @staticmethod
    def _parse_pattern_key(pattern_key: str) -> List[str]:
        """Split a comma-joined pattern key into stripped, non-empty items."""
        return [item.strip() for item in pattern_key.split(',') if item.strip()]

    def _mine_candidates(self, mining_params: Dict[str, Any]
                         ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Mine patterns and prepare the top candidates for scoring.

        Args:
            mining_params: Mining parameters (missing keys fall back to defaults)

        Returns:
            (all mined patterns, top ``TOP_K_CANDIDATES`` patterns). Each
            candidate carries an ``'embedding'`` entry; missing ones are
            computed here and written into the pattern dict in place.
        """
        patterns = self.miner.mine_patterns(
            min_participation=mining_params.get('min_participation', 0.6),
            max_pattern_size=mining_params.get('max_pattern_size', 5),
            priority=mining_params.get('priority', 'confidence')
        )
        top_patterns = patterns[:self.TOP_K_CANDIDATES]
        for pattern in top_patterns:
            if 'embedding' not in pattern:
                pattern['embedding'] = self.embedder.encode_pattern(pattern['pattern'])
        return patterns, top_patterns

    def _score_candidates(self, state: "IterationState",
                          top_patterns: List[Dict[str, Any]],
                          interaction_round: int):
        """Score candidates with the state vector, else via the learner.

        Returns:
            Scores aligned with ``top_patterns``, or None when the learner
            has nothing to score with.
        """
        if state.user_vector is not None:
            return self.learner._score_with_vector(top_patterns, state.user_vector)
        # No state vector yet: let the learner use its own (fused) vector.
        return self.learner.score_patterns(
            top_patterns,
            user_id=state.user_id,
            interaction_round=interaction_round
        )

    @staticmethod
    def _rank_by_scores(top_patterns: List[Dict[str, Any]], scores
                        ) -> Tuple[List[Dict[str, Any]], Dict[str, float]]:
        """Order candidates by descending score.

        Returns:
            (ranked patterns, mapping of sorted comma-joined pattern key to
            score). When ``scores`` is None the input order is kept and the
            mapping is empty.
        """
        if scores is None:
            return top_patterns, {}
        # sorted() is stable, so equal scores keep the miner's order.
        ranked = sorted(zip(top_patterns, scores), key=lambda x: x[1], reverse=True)
        ranked_patterns = [p for p, _ in ranked]
        similarity_scores = {
            ','.join(sorted(p.get('pattern', []))): float(score)
            for p, score in ranked
        }
        return ranked_patterns, similarity_scores

    def _store_feedback(self, user_id: str, pattern_keys, store) -> None:
        """Encode each feedback key and pass the vector to ``store``.

        Keys that are empty after parsing, and patterns the embedder cannot
        encode (returns None), are skipped.
        """
        for pattern_key in pattern_keys:
            pattern_list = self._parse_pattern_key(pattern_key)
            if not pattern_list:
                continue
            pattern_vec = self.embedder.encode_pattern(pattern_list)
            if pattern_vec is not None:
                store(user_id, pattern_vec)

    def _load_llm_vector(self, user_id: str) -> Optional[np.ndarray]:
        """Return the stored LLM intent vector for ``user_id``, or None."""
        intent_data = self.memory.load_intent(user_id)
        if not intent_data:
            return None
        stored = intent_data.get("u_llm")
        # Compare with None explicitly: the stored value may be an ndarray,
        # whose truth value is ambiguous (bool() raises ValueError).
        if stored is None:
            return None
        u_llm = np.array(stored)
        return u_llm if u_llm.size > 0 else None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def init_user_vector(self, query: str, user_id: str) -> Optional[np.ndarray]:
        """Initialize user vector from LLM intent understanding (Stage0).

        This method implements the complete flow:
        1. User ambiguous query → LLM intent understanding
        2. Extract structured intent (business, pattern_preference, risk factors)
        3. Validate patterns against dataset POI types
        4. Map patterns to embeddings (feature selection)
        5. Compute initial user vector u_llm

        Args:
            query: User query text
            user_id: User identifier

        Returns:
            Initial user preference vector, or None if Stage0 is not
            available or any step fails (failures are logged, not raised).
        """
        if self.intent_encoder is None or self.intent_mapper is None:
            logger.info("Stage4: Stage0 not available, skipping intent initialization")
            return None

        try:
            logger.info(f"Stage4: Starting intent understanding for query: {query[:100]}...")

            # Step 1: Parse intent using LLM (intent understanding)
            intent_result = self.intent_encoder.parse(query)
            if not intent_result:
                logger.warning("Stage4: Failed to parse intent from query - LLM returned None")
                return None

            # The mapper needs at least one preferred pattern to build a vector.
            pattern_preference = intent_result.get("pattern_preference", [])
            if not pattern_preference:
                logger.warning("Stage4: No pattern_preference found in parsed intent")
                logger.debug(f"Stage4: Intent result keys: {intent_result.keys()}")
                return None

            logger.info(f"Stage4: Successfully parsed intent - found {len(pattern_preference)} preferred patterns")
            logger.info(f"Stage4: Preferred patterns: {pattern_preference}")

            # Step 2: Map intent to vector (feature selection and vector computation)
            logger.info("Stage4: Mapping intent patterns to user vector (feature selection)...")
            u_llm = self.intent_mapper.to_vector(intent_result)

            if u_llm is None:
                logger.warning("Stage4: Failed to convert intent to vector")
                return None

            logger.info(f"Stage4: Generated initial user vector u_llm (dim={len(u_llm)}, norm={np.linalg.norm(u_llm):.4f})")

            # Step 3: Save intent to memory so later fusion can reload u_llm
            self.memory.save_intent(user_id, u_llm, intent_result)
            logger.info("Stage4: Saved intent data to memory")

            logger.info(f"Stage4: Successfully initialized user vector from LLM intent (dim={len(u_llm)})")
            return u_llm
        except Exception as e:
            # Stage0 is optional: degrade to "no initial vector" rather than abort.
            logger.warning(f"Stage4: Failed to initialize user vector from intent: {e}", exc_info=True)
            return None

    def run_one_round(self, state: "IterationState",
                      mining_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run one iteration round.

        Algorithm:
        1. Mine candidate patterns
        2. Rank patterns by current user vector
        3. Return ranked patterns for user feedback

        Increments ``state.current_round`` and appends a summary to
        ``state.history``.

        Args:
            state: Current iteration state
            mining_params: Optional mining parameters (if None, uses defaults)

        Returns:
            Dictionary containing ranked patterns and metadata
        """
        state.current_round += 1
        logger.info(f"Stage4: Starting round {state.current_round}/{state.max_rounds}")

        if mining_params is None:
            mining_params = self._default_mining_params()

        # Step 1: Mine patterns
        patterns, top_patterns = self._mine_candidates(mining_params)
        logger.info(f"Stage4: Round {state.current_round} mined {len(patterns)} patterns")

        # Step 2: Score and rank. current_round was already incremented, so
        # current_round - 1 is the number of rounds completed before this one.
        scores = self._score_candidates(state, top_patterns, state.current_round - 1)
        if state.user_vector is not None:
            logger.info("Stage4: Scored patterns using state user vector")
        elif scores is not None:
            logger.info("Stage4: Scored patterns using learner's fused vector")

        ranked_patterns, similarity_scores = self._rank_by_scores(top_patterns, scores)
        if scores is not None:
            logger.info(f"Stage4: Re-ranked {len(ranked_patterns)} patterns")
        else:
            logger.info("Stage4: No user vector available, using original ranking")

        # Step 3: Generate rules from the ranked candidates
        min_confidence = mining_params.get('min_confidence', 0.5)
        rules = self.miner.generate_rules(ranked_patterns, min_confidence=min_confidence)

        user_vector_norm = self._vector_norm(state.user_vector)
        round_result = {
            'round': state.current_round,
            'patterns': ranked_patterns,
            'rules': rules,
            'total_patterns': len(patterns),
            'total_rules': len(rules),
            'similarity_scores': similarity_scores,
            'user_vector_norm': user_vector_norm
        }

        state.add_round({
            'total_patterns': len(patterns),
            'total_rules': len(rules),
            'user_vector_norm': user_vector_norm
        })

        return round_result

    def update_user_vector(self, state: "IterationState",
                           feedback: Dict[str, List[str]]) -> np.ndarray:
        """Update user vector based on feedback.

        Formula: u_t = α * u_llm + (1-α) * u_feedback

        Feedback keys are comma-joined pattern items (the same format as the
        ``similarity_scores`` keys). Whitespace around items is ignored and
        empty keys are skipped. If only one of u_llm / u_feedback exists it
        is used alone. If neither exists, a zero vector of the embedder's
        dimension is used.

        Args:
            state: Current iteration state (``state.user_vector`` is replaced)
            feedback: Dictionary with 'positive' and 'negative' pattern lists

        Returns:
            Updated user preference vector
        """
        # Persist feedback so the learner can aggregate it.
        self._store_feedback(state.user_id, feedback.get('positive') or [],
                             self.memory.add_positive)
        self._store_feedback(state.user_id, feedback.get('negative') or [],
                             self.memory.add_negative)

        u_feedback = self.learner.compute_feedback_vector(state.user_id)
        u_llm = self._load_llm_vector(state.user_id)

        if u_llm is not None and u_feedback is not None:
            u_t = self.fusion_alpha * u_llm + (1 - self.fusion_alpha) * u_feedback
            logger.info(f"Stage4: Fused user vector (alpha={self.fusion_alpha})")
        elif u_feedback is not None:
            u_t = u_feedback
            logger.info("Stage4: Using feedback vector only (no LLM intent)")
        elif u_llm is not None:
            u_t = u_llm
            logger.info("Stage4: Using LLM intent vector only (no feedback)")
        else:
            dim = self.embedder.get_sentence_embedding_dimension()
            u_t = np.zeros(dim)
            logger.warning("Stage4: No user vector available, using zero vector")

        state.user_vector = u_t

        logger.info(f"Stage4: Updated user vector (norm={np.linalg.norm(u_t):.4f})")
        return u_t

    def run_one_iteration_step(self, state: "IterationState",
                               mining_params: Optional[Dict[str, Any]] = None,
                               trainer=None) -> Dict[str, Any]:
        """Run one iteration step: mine, rank, return results for user feedback.

        This method is designed for interactive iteration where:
        1. Mine and rank patterns
        2. Return results for user to provide feedback
        3. After feedback, call update_and_train() to update vector and train model

        Args:
            state: Current iteration state
            mining_params: Optional mining parameters
            trainer: Accepted for API compatibility; not used by this method
                (pass it to update_and_train() instead)

        Returns:
            The round result from run_one_round() plus ``'is_final'``
            (True once the state has reached its maximum number of rounds)
        """
        round_result = self.run_one_round(state, mining_params)

        step_result = dict(round_result)
        step_result['is_final'] = state.is_complete()
        return step_result

    def update_and_train(self, state: "IterationState",
                         feedback: Dict[str, List[str]],
                         trainer=None,
                         training_epochs: int = 5) -> Dict[str, Any]:
        """Update user vector based on feedback and immediately train model.

        Training is best-effort: a failing trainer is logged and reported as
        ``{'trained': False, 'error': ...}`` instead of raising.

        Args:
            state: Current iteration state
            feedback: Dictionary with 'positive' and 'negative' pattern lists
            trainer: Optional PreferenceTrainer instance
            training_epochs: Number of epochs for training (default 5 for quick training)

        Returns:
            Dictionary with the new vector norm and training results
        """
        updated_vector = self.update_user_vector(state, feedback)

        if trainer is None:
            training_result = {
                'trained': False,
                'reason': 'No trainer provided'
            }
        else:
            try:
                logger.info(f"Stage4: Training model after round {state.current_round} feedback")
                history = trainer.train(
                    epochs=training_epochs,
                    batch_size=32,
                    save_path=None  # Don't save intermediate models
                )
                # A trainer that returns no history still trained successfully.
                loss_history = (history or {}).get('loss', [])
                training_result = {
                    'trained': True,
                    'epochs': training_epochs,
                    'loss_history': loss_history
                }
                last_loss = loss_history[-1] if loss_history else 'N/A'
                logger.info(f"Stage4: Model trained successfully (loss: {last_loss})")
            except Exception as e:
                logger.warning(f"Stage4: Training failed: {e}", exc_info=True)
                training_result = {
                    'trained': False,
                    'error': str(e)
                }

        return {
            'user_vector_norm': self._vector_norm(updated_vector),
            'training': training_result
        }

    def run(self, query: str, user_id: str, K: int,
            mining_params: Optional[Dict[str, Any]] = None,
            feedback_fn=None) -> Dict[str, Any]:
        """Run iterative preference refinement.

        Algorithm: Iterative Preference Refinement

        Input: Q (query), K (rounds)
        Output: P* (best patterns)

        Initialize u0
        for t=1..K:
          Generate Ct (candidate patterns)
          Rank Ct by ut
          Collect Ft (feedback)
          Update ut

        Return best

        Feedback is only collected when ``feedback_fn`` is given. Without it
        every round ranks with the initial vector u0.

        Args:
            query: User query
            user_id: User identifier
            K: Number of iteration rounds
            mining_params: Optional mining parameters
            feedback_fn: Optional callable taking a round result dict and
                returning a feedback dict ({'positive': [...],
                'negative': [...]}) or a falsy value to skip the update

        Returns:
            Dictionary containing final results and iteration history
        """
        logger.info(f"Stage4: Starting iterative refinement (K={K})")

        state = IterationState(user_id, query, K)

        # Initialize user vector from LLM intent (None if Stage0 unavailable)
        state.user_vector = self.init_user_vector(query, user_id)

        iteration_history = []

        while not state.is_complete():
            round_result = self.run_one_round(state, mining_params)
            iteration_history.append(round_result)

            logger.info(f"Stage4: Round {state.current_round}/{K} completed")

            # Collect Ft and update ut when a feedback source is supplied.
            if feedback_fn is not None:
                feedback = feedback_fn(round_result)
                if feedback:
                    self.update_user_vector(state, feedback)

        final_result = self._final_rank(state, mining_params)

        return {
            'query': query,
            'total_rounds': K,
            'iteration_history': iteration_history,
            'final_patterns': final_result['patterns'],
            'final_rules': final_result['rules'],
            'final_similarity_scores': final_result.get('similarity_scores', {}),
            'user_vector_norm': self._vector_norm(state.user_vector)
        }

    def _final_rank(self, state: "IterationState",
                    mining_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Generate final ranking after all iterations.

        Args:
            state: Final iteration state
            mining_params: Optional mining parameters

        Returns:
            Dictionary containing final ranked patterns and rules
        """
        logger.info("Stage4: Generating final ranking")

        if mining_params is None:
            mining_params = self._default_mining_params()

        _, top_patterns = self._mine_candidates(mining_params)

        # All rounds are done, so current_round is the completed-round count.
        scores = self._score_candidates(state, top_patterns, state.current_round)
        ranked_patterns, similarity_scores = self._rank_by_scores(top_patterns, scores)

        min_confidence = mining_params.get('min_confidence', 0.5)
        rules = self.miner.generate_rules(ranked_patterns, min_confidence=min_confidence)

        return {
            'patterns': ranked_patterns,
            'rules': rules,
            'similarity_scores': similarity_scores
        }

