"""Pipeline manager for coordinating the entire system."""

import logging
import os
import numpy as np
import uuid
from typing import Dict, Any, Optional, List

import yaml

# 使用绝对导入（项目根目录已在 sys.path 中）
from llm import LLMClient, PromptTemplate
from preference import PreferenceParser
from preference.scorer import build_user_vector, score_patterns as score_with_preference
from core import CoLocationMiner
from memory.store import MemoryStore
from learning.learner import PreferenceLearner
from learning.embedder import PatternEmbedder

# Module logger is created before the optional imports so their fallback
# warnings go through it. Calling the root-level ``logging.warning`` here
# would implicitly run ``logging.basicConfig()`` at import time when no
# handler is configured yet, making any later ``basicConfig`` call by the
# application a silent no-op.
logger = logging.getLogger(__name__)

# Stage0: Intent understanding components (optional).
try:
    from llm.intent_encoder import IntentEncoder
    from llm.intent_mapper import IntentMapper
    STAGE0_AVAILABLE = True
except ImportError:
    STAGE0_AVAILABLE = False
    logger.warning("Stage0 intent understanding components not available")

# Stage3: Optional contrastive learning components (need PyTorch).
try:
    from embedding.encoder import PreferenceEncoder
    from learning.trainer import PreferenceTrainer
    STAGE3_CONTRASTIVE_AVAILABLE = True
except ImportError:
    STAGE3_CONTRASTIVE_AVAILABLE = False
    logger.warning("Stage3 contrastive learning components not available (PyTorch required)")

# Stage4: Iterative interaction components (optional).
try:
    from controller.iteration_manager import IterationManager
    STAGE4_AVAILABLE = True
except ImportError:
    STAGE4_AVAILABLE = False
    logger.warning("Stage4 iterative interaction components not available")


class PipelineManager:
    """Manager for coordinating the entire pipeline."""
    
    def __init__(self, config_path: str = "config/config.yaml"):
        """Initialize pipeline manager.
        
        Args:
            config_path: Path to configuration file
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)
        
        # 初始化各个组件
        logger.info("Initializing pipeline components...")
        self.llm_client = LLMClient(config_path)
        self.preference_parser = PreferenceParser(
            default_params=self.config.get('mining', {})
        )
        self.miner = CoLocationMiner(
            data_path=self.config['data']['data_path'],
            distance_threshold=16.0, # 可以从配置读取
            use_gpu=True
              
        )
        # Stage0: 加载配置
        stage0_config = self.config.get('stage0', {})
        self.stage0_enabled = stage0_config.get('enabled', True) and STAGE0_AVAILABLE
        self.stage0_decay_lambda = stage0_config.get('decay_lambda', 0.05)
        self.stage0_min_confidence = stage0_config.get('min_confidence', 0.6)
        
        # Stage3: 初始化记忆和学习模块
        self.memory = MemoryStore()
        decay_lambda = self.stage0_decay_lambda if self.stage0_enabled else 0.05
        self.learner = PreferenceLearner(
            decay_lambda=decay_lambda,
            use_preference_weighted=self.config.get('stage3', {}).get('use_preference_weighted', False)
        )
        self.embedder = PatternEmbedder()
        
        # Stage0: 初始化意图理解组件
        self.intent_encoder = None
        self.intent_mapper = None
        if self.stage0_enabled:
            try:
                # 获取数据集中实际的 POI 类型
                available_poi_types = list(self.miner.feature_counts.keys())
                self.intent_encoder = IntentEncoder(self.llm_client, available_poi_types=available_poi_types)
                self.intent_mapper = IntentMapper(self.embedder)
                logger.info(f"Stage0: Intent understanding components initialized with {len(available_poi_types)} POI types")
            except Exception as e:
                logger.warning(f"Stage0: Failed to initialize intent components: {e}", exc_info=True)
                self.stage0_enabled = False
        
        # Stage3: 加载配置
        stage3_config = self.config.get('stage3', {})
        self.stage3_enabled = stage3_config.get('enabled', True)
        self.stage3_use_contrastive = stage3_config.get('use_contrastive', False) and STAGE3_CONTRASTIVE_AVAILABLE
        self.stage3_use_preference_weighted = stage3_config.get('use_preference_weighted', False)
        self.stage3_retrain_interval = stage3_config.get('retrain_interval', 20)
        self.stage3_alpha = stage3_config.get('alpha', 0.7)
        self.stage3_user_id = stage3_config.get('user_id', 'user_001')
        self.stage3_model_path = stage3_config.get('model_save_path', './models/preference_encoder.pth')
        
        # 跟踪交互轮数（用于 Stage0 衰减）
        self.interaction_count = 0
        
        # Stage4: 存储迭代会话状态（用于交互式迭代）
        self.iteration_sessions = {}  # {session_id: IterationState}
        
        # Stage4: 加载配置
        stage4_config = self.config.get('stage4', {})
        self.stage4_enabled = stage4_config.get('enabled', False) and STAGE4_AVAILABLE
        self.stage4_default_rounds = stage4_config.get('default_rounds', 3)
        self.stage4_max_rounds = stage4_config.get('max_rounds', 5)
        self.stage4_fusion_alpha = stage4_config.get('fusion_alpha', 0.6)
        
        # Stage4: 初始化迭代管理器
        self.iteration_manager = None
        if self.stage4_enabled:
            try:
                self.iteration_manager = IterationManager(
                    miner=self.miner,
                    learner=self.learner,
                    memory=self.memory,
                    embedder=self.embedder,
                    llm_client=self.llm_client,
                    intent_encoder=self.intent_encoder,
                    intent_mapper=self.intent_mapper,
                    fusion_alpha=self.stage4_fusion_alpha
                )
                logger.info(f"Stage4: IterationManager initialized (default_rounds={self.stage4_default_rounds}, max_rounds={self.stage4_max_rounds})")
            except Exception as e:
                logger.warning(f"Stage4: Failed to initialize IterationManager: {e}", exc_info=True)
                self.stage4_enabled = False
        
        # Stage3: 初始化对比学习组件（如果启用）
        self.preference_model = None
        self.trainer = None
        if self.stage3_use_contrastive:
            try:
                # 初始化训练器（无论是否有已训练的模型）
                input_dim = self.embedder.model.get_sentence_embedding_dimension()
                hidden_dim = stage3_config.get('hidden_dim', 256)
                
                # 如果已有模型文件，先加载模型
                if os.path.exists(self.stage3_model_path):
                    self.preference_model = PreferenceEncoder(
                        input_dim=input_dim,
                        hidden_dim=hidden_dim
                    )
                    self.trainer = PreferenceTrainer(
                        memory_store=self.memory,
                        embedder=self.embedder,
                        model=self.preference_model,
                        margin=stage3_config.get('margin', 0.3),
                        learning_rate=stage3_config.get('learning_rate', 0.001),
                        use_cosine_loss=stage3_config.get('use_cosine_loss', True),
                        user_id=self.stage3_user_id
                    )
                    self.trainer.load_model(self.stage3_model_path)
                    logger.info(f"Stage3: Loaded trained preference encoder from {self.stage3_model_path}")
                else:
                    # 没有已训练模型，创建新的训练器（模型会在训练时创建）
                    self.trainer = PreferenceTrainer(
                        memory_store=self.memory,
                        embedder=self.embedder,
                        model=None,  # 让训练器自己创建新模型
                        input_dim=input_dim,
                        hidden_dim=hidden_dim,
                        margin=stage3_config.get('margin', 0.3),
                        learning_rate=stage3_config.get('learning_rate', 0.001),
                        use_cosine_loss=stage3_config.get('use_cosine_loss', True),
                        user_id=self.stage3_user_id
                    )
                    logger.info("Stage3: Trainer initialized (no pre-trained model, will create new model on first training)")
            except Exception as e:
                logger.warning(f"Stage3: Failed to initialize contrastive learning: {e}", exc_info=True)
                self.stage3_use_contrastive = False
                self.trainer = None
        
        logger.info(f"Pipeline initialized successfully (Stage0: enabled={self.stage0_enabled}, "
                   f"Stage3: enabled={self.stage3_enabled}, contrastive={self.stage3_use_contrastive}, "
                   f"Stage4: enabled={self.stage4_enabled})")
    
    def is_ambiguous(self, text: str) -> bool:
        """Check if user query is ambiguous (requires intent understanding).
        
        Args:
            text: User query text
            
        Returns:
            True if query is ambiguous, False otherwise
        """
        keywords = ["推荐", "适合", "开店", "选址", "发展", "建议", "应该", "想要", "希望"]
        return any(k in text for k in keywords)
    
    def process_query(self, user_input: str, iteration_rounds: Optional[int] = None) -> Dict[str, Any]:
        """Process user query through the entire pipeline.
        
        Args:
            user_input: User's natural language query
            iteration_rounds: Optional number of iteration rounds (for Stage4). 
                            If None, uses default from config or single-shot mode.
            
        Returns:
            Dictionary containing extracted parameters, patterns, and explanation
        """
        logger.info(f"Processing query: {user_input}")
        
        # Stage4: Use iterative interaction if enabled
        if self.stage4_enabled and self.iteration_manager is not None:
            if iteration_rounds is None:
                iteration_rounds = self.stage4_default_rounds
            elif iteration_rounds > self.stage4_max_rounds:
                iteration_rounds = self.stage4_max_rounds
            
            logger.info(f"Stage4: Using iterative interaction mode (K={iteration_rounds})")
            
            # Extract mining parameters from query (for iterative mining)
            preference_prompt = PromptTemplate.PREFERENCE_EXTRACTION_PROMPT.format(
                user_input=user_input
            )
            system_prompt = PromptTemplate.SYSTEM_PROMPT
            llm_response = self.llm_client.generate(preference_prompt, system=system_prompt)
            params = self.preference_parser.parse_preference(llm_response)
            
            # Run iterative refinement
            iteration_result = self.iteration_manager.run(
                query=user_input,
                user_id=self.stage3_user_id,
                K=iteration_rounds,
                mining_params=params
            )
            
            # Generate explanation
            explanation = self._generate_explanation(
                params,
                iteration_result['final_patterns'],
                iteration_result['final_rules']
            )
            
            # Update interaction count
            self.interaction_count += iteration_rounds
            
            return {
                'user_input': user_input,
                'extracted_parameters': params,
                'patterns': iteration_result['final_patterns'],
                'rules': iteration_result['final_rules'],
                'explanation': explanation,
                'pattern_count': len(iteration_result['final_patterns']),
                'rule_count': len(iteration_result['final_rules']),
                're_ranked': True,  # Stage4 always uses ranking
                'similarity_scores': iteration_result.get('final_similarity_scores', {}),
                'intent_used': True,  # Stage4 uses intent understanding
                'intent_data': None,  # Will be loaded from memory if needed
                'total_rounds': iteration_result['total_rounds'],
                'iteration_history': iteration_result['iteration_history']
            }
        
        # Original single-shot mode (Stage3)
        
        # Stage0: Intent understanding for ambiguous queries
        intent_used = False
        intent_data = None
        if self.stage0_enabled and self.is_ambiguous(user_input):
            logger.info("Stage0: Detected ambiguous query, performing intent understanding...")
            try:
                # Parse intent using LLM
                intent_json = self.intent_encoder.parse(user_input)
                if intent_json:
                    # Convert intent to user vector
                    u_llm = self.intent_mapper.to_vector(intent_json)
                    if u_llm is not None:
                        # Save intent to memory
                        self.memory.save_intent(self.stage3_user_id, u_llm, intent_json)
                        intent_used = True
                        intent_data = intent_json
                        logger.info(f"Stage0: Generated initial user vector from intent (dim: {u_llm.shape})")
                    else:
                        logger.warning("Stage0: Failed to convert intent to vector")
                else:
                    logger.warning("Stage0: Failed to parse intent from query")
            except Exception as e:
                logger.error(f"Stage0: Error in intent understanding: {e}", exc_info=True)
        
        # Step 1: Extract preferences using LLM
        logger.info("Step 1: Extracting preferences from user input...")
        preference_prompt = PromptTemplate.PREFERENCE_EXTRACTION_PROMPT.format(
            user_input=user_input
        )
        
        system_prompt = PromptTemplate.SYSTEM_PROMPT
        llm_response = self.llm_client.generate(preference_prompt, system=system_prompt)
        logger.debug(f"LLM response: {llm_response}")
        print('llm_response',llm_response)
        # Step 2: Parse preferences
        logger.info("Step 2: Parsing preferences...")
        params = self.preference_parser.parse_preference(llm_response)
        logger.info(f"Extracted parameters: {params}")
        
        # Step 3: Mine patterns
        logger.info("Step 3: Mining co-location patterns...")
        patterns = self.miner.mine_patterns(
            min_participation=params['min_participation'],
            max_pattern_size=params['max_pattern_size'],
            priority=params['priority']
        )
        logger.info(f"Found {len(patterns)} patterns")
        
        # Step 4: Generate association rules from patterns
        logger.info("Step 4: Generating association rules...")
        min_confidence = params.get('min_confidence', 0.5)  # 默认最小置信度
        rules = self.miner.generate_rules(patterns, min_confidence=min_confidence)
        logger.info(f"Generated {len(rules)} rules")
        
        # Stage3: 使用学习结果重排模式（关键闭环）
        re_ranked = False
        similarity_scores = None
        
        if self.stage3_enabled:
            top_patterns = patterns[:20]  # 取前20个模式用于重排序
            
            # 为模式计算embedding（如果还没有）
            for pattern in top_patterns:
                if 'embedding' not in pattern:
                    # pattern['pattern'] 已经是 List[str]，直接传入
                    pattern_vec = self.embedder.encode_pattern(pattern['pattern'])
                    pattern['embedding'] = pattern_vec
            
            # 尝试获取用户向量（包括意图向量和反馈向量）
            # 优先使用 learner 的融合方法（支持 Stage0 意图向量）
            scores = None
            
            # 方法1: 使用 learner 的融合向量（支持 Stage0 + Stage3）
            if self.stage0_enabled or self.stage3_enabled:
                scores = self.learner.score_patterns(
                    top_patterns,
                    user_id=self.stage3_user_id,
                    interaction_round=self.interaction_count
                )
                if scores is not None:
                    logger.info("Stage3: Using fused user vector (intent + feedback) for scoring")
            
            # 方法2: 如果 learner 返回 None，尝试使用传统方法
            if scores is None:
                user_profile = self.memory.get_user_profile(self.stage3_user_id)
                if user_profile:
                    user_vec = build_user_vector(user_profile, dim=self.embedder.model.get_sentence_embedding_dimension())
                    
                    # 使用新的打分器（支持对比学习模型）
                    if self.stage3_use_contrastive and self.preference_model is not None:
                        scores = score_with_preference(
                            top_patterns, 
                            user_vec, 
                            model=self.preference_model,
                            alpha=self.stage3_alpha
                        )
                        logger.info("Stage3: Using contrastive learning model for scoring")
                    else:
                        # 使用baseline scorer
                        scores = score_with_preference(
                            top_patterns,
                            user_vec,
                            model=None,
                            alpha=self.stage3_alpha
                        )
                        logger.info("Stage3: Using baseline scorer for scoring")
            
            # 方法3: 如果仍然没有分数，检查是否有意图向量（Stage0）
            if scores is None and self.stage0_enabled:
                intent_data = self.memory.load_intent(self.stage3_user_id)
                if intent_data and intent_data.get("u_llm"):
                    u_llm = np.array(intent_data["u_llm"])
                    # 使用意图向量直接计算相似度
                    scores = self.learner._score_with_vector(top_patterns, u_llm)
                    logger.info("Stage0: Using intent vector for scoring (no feedback yet)")
            
            # 如果有分数，进行重排序
            if scores is not None:
                logger.info("Stage3: Re-ranking patterns based on user preferences...")
                # 将模式和分数配对，按分数降序排序
                ranked = sorted(
                    zip(top_patterns, scores),
                    key=lambda x: x[1],
                    reverse=True
                )
                # 提取重排序后的模式和分数
                top_patterns = [p for p, _ in ranked]
                # 创建模式到相似度分数的映射
                similarity_scores = {}
                for p, score in ranked:
                    pattern_key = ','.join(sorted(p.get('pattern', [])))
                    similarity_scores[pattern_key] = score
                # 将重排序后的模式放回原列表
                patterns = top_patterns + patterns[20:]
                re_ranked = True
                logger.info(f"Stage3: Re-ranked {len(top_patterns)} patterns based on user preferences")
            else:
                logger.info("Stage3: No user vector available (no intent and no feedback), using original ranking")
        
        # Step 5: Generate explanation using LLM
        logger.info("Step 5: Generating explanation...")
        explanation = self._generate_explanation(params, patterns, rules)
        
        # 增加交互计数
        self.interaction_count += 1
        
        return {
            'user_input': user_input,
            'extracted_parameters': params,
            'patterns': patterns,
            'rules': rules,
            'explanation': explanation,
            'pattern_count': len(patterns),
            'rule_count': len(rules),
            're_ranked': re_ranked,  # Stage3: 标记是否已重排序
            'similarity_scores': similarity_scores,  # Stage3: 相似度分数
            'intent_used': intent_used,  # Stage0: 是否使用了意图理解
            'intent_data': intent_data  # Stage0: 意图数据
        }
    
    def _generate_explanation(self, params: Dict[str, Any], 
                              patterns: list, rules: list = None) -> str:
        """Generate explanation for mining results.
        
        Args:
            params: Mining parameters
            patterns: Discovered patterns
            
        Returns:
            Explanation text
        """
        # 格式化模式信息
        pattern_info = []
        for i, pattern in enumerate(patterns[:10], 1):  # 只显示前10个
            pattern_str = "{" + ", ".join(pattern['pattern']) + "}"
            pattern_info.append(
                f"{i}. {pattern_str} (参与率: {pattern['participation_index']:.3f}, "
                f"置信度: {pattern['confidence']:.3f}, 大小: {pattern['size']})"
            )
        
        patterns_text = "\n".join(pattern_info) if pattern_info else "未发现满足条件的模式"
        
        # 构建解释提示
        explanation_prompt = PromptTemplate.EXPLANATION_GENERATION_PROMPT.format(
            min_participation=params['min_participation'],
            max_pattern_size=params['max_pattern_size'],
            priority=params['priority'],
            patterns=patterns_text
        )
        
        system_prompt = PromptTemplate.SYSTEM_PROMPT
        explanation = self.llm_client.generate(explanation_prompt, system=system_prompt)
        
        return explanation
    
    def train_model(self) -> bool:
        """手动触发模型训练。
        
        Returns:
            True if training succeeded, False otherwise
        """
        if not self.stage3_enabled:
            print("\n❌ Stage3 未启用")
            return False
        
        if not self.stage3_use_contrastive:
            print("\n❌ 对比学习未启用")
            return False
        
        if self.trainer is None:
            print("\n❌ 训练器未初始化（可能需要安装 PyTorch）")
            return False
        
        user_profile = self.memory.get_user_profile(self.stage3_user_id)
        if not user_profile:
            print("\n❌ 未找到用户反馈数据，请先进行查询并提供反馈")
            return False
        
        positive_count = len(user_profile.get('positive', []))
        negative_count = len(user_profile.get('negative', []))
        total_feedback = positive_count + negative_count
        
        if total_feedback == 0:
            print("\n❌ 没有反馈数据，无法训练")
            return False
        
        if positive_count == 0 or negative_count == 0:
            print(f"\n⚠️  警告: 需要同时有正反馈和负反馈才能训练")
            print(f"  当前: 正反馈 {positive_count} 个, 负反馈 {negative_count} 个")
            print(f"  建议: 至少各有 1 个反馈才能生成有效的 triplet 数据")
            response = input("  是否继续训练? (y/n): ").strip().lower()
            if response != 'y':
                return False
        
        print(f"\n{'='*60}")
        print("【Stage3 训练开始】")
        print(f"  反馈总数: {total_feedback} (正: {positive_count}, 负: {negative_count})")
        print(f"{'='*60}\n")
        
        try:
            stage3_config = self.config.get('stage3', {})
            epochs = stage3_config.get('epochs', 10)
            
            history = self.trainer.train(
                epochs=epochs,
                batch_size=stage3_config.get('batch_size', 32),
                save_path=self.stage3_model_path
            )
            
            final_loss = history['loss'][-1] if history['loss'] else 0.0
            initial_loss = history['loss'][0] if history['loss'] else 0.0
            
            # 保存训练历史
            self._save_training_history(total_feedback, epochs, history)
            
            logger.info(f"Stage3: Training completed, final loss: {final_loss:.4f}")
            print(f"\n{'='*60}")
            print("【Stage3 训练完成】")
            print(f"  训练轮数: {epochs}")
            print(f"  初始Loss: {initial_loss:.4f}")
            print(f"  最终Loss: {final_loss:.4f}")
            print(f"  Loss下降: {initial_loss - final_loss:.4f}")
            print(f"  模型已保存: {self.stage3_model_path}")
            print(f"{'='*60}\n")
            
            # 重新加载模型
            self.preference_model = self.trainer.model
            return True
            
        except Exception as e:
            logger.error(f"Stage3: Training failed: {e}", exc_info=True)
            print(f"\n❌ 训练失败: {e}\n")
            return False
    
    def run_pipeline(self):
        """Run interactive pipeline (CLI mode)."""
        print("=" * 60)
        print("交互式 Co-location 模式挖掘系统 (MVP)")
        print("=" * 60)
        print("输入 'quit' 或 'exit' 退出")
        print("输入 'train' 手动触发模型训练")
        print()
        
        while True:
            try:
                user_input = input("请输入查询: ").strip()
                
                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("再见！")
                    break
                
                if user_input.lower() in ['train', 'training']:
                    # 手动触发训练
                    self.train_model()
                    continue
                
                if not user_input:
                    continue
                
                # 处理查询
                result = self.process_query(user_input)
                
                # 显示结果
                self._display_results(result)
                
                # Stage3: 采集用户反馈并写入 Memory
                if result['patterns'] and self.stage3_enabled:
                    feedback = self.collect_feedback(result['patterns'])
                    
                    # 保存会话历史（需要清理不能序列化的字段）
                    # 创建模式的副本，移除 embedding 字段（numpy array 不能序列化为 JSON）
                    patterns_to_save = []
                    for pattern in result['patterns'][:20]:
                        pattern_copy = pattern.copy()
                        # 移除 embedding 字段（如果需要可以保存为列表，但通常不需要）
                        if 'embedding' in pattern_copy:
                            del pattern_copy['embedding']
                        patterns_to_save.append(pattern_copy)
                    
                    self.memory.add_session(
                        query=user_input,
                        params=result['extracted_parameters'],
                        patterns=patterns_to_save,
                        feedback=feedback
                    )
                    
                    # 保存向量到用户记忆
                    liked_patterns = feedback.get('like', [])
                    disliked_patterns = feedback.get('dislike', [])
                    
                    for pattern in liked_patterns:
                        if isinstance(pattern, dict):
                            pattern_list = pattern.get('pattern', [])
                        else:
                            pattern_list = pattern if isinstance(pattern, list) else []
                        
                        if pattern_list:
                            # pattern_list 已经是 List[str]，直接传入
                            pattern_vec = self.embedder.encode_pattern(pattern_list)
                            self.memory.add_positive(self.stage3_user_id, pattern_vec)
                    
                    for pattern in disliked_patterns:
                        if isinstance(pattern, dict):
                            pattern_list = pattern.get('pattern', [])
                        else:
                            pattern_list = pattern if isinstance(pattern, list) else []
                        
                        if pattern_list:
                            # pattern_list 已经是 List[str]，直接传入
                            pattern_vec = self.embedder.encode_pattern(pattern_list)
                            self.memory.add_negative(self.stage3_user_id, pattern_vec)
                    
                    logger.info(f"Stage3: Feedback saved to memory ({len(liked_patterns)} liked, "
                               f"{len(disliked_patterns)} disliked)")
                    
                    # 不再自动触发训练，改为手动触发
                
            except KeyboardInterrupt:
                print("\n\n再见！")
                break
            except Exception as e:
                logger.error(f"Error processing query: {e}", exc_info=True)
                print(f"错误: {e}")
    
    def _display_results(self, result: Dict[str, Any]):
        """Display results in a formatted way.
        
        Args:
            result: Result dictionary from process_query
        """
        print("\n" + "=" * 60)
        print("系统输出:")
        print("=" * 60)
        
        # 显示提取的参数
        print("\n【提取的参数】")
        params = result['extracted_parameters']
        print(f"  最小参与率: {params['min_participation']}")
        print(f"  最大模式大小: {params['max_pattern_size']}")
        print(f"  优先级: {params['priority']}")
        
        # Stage4: 显示迭代历史（如果有多轮交互）
        if result.get('total_rounds', 0) > 1:
            print("\n【迭代历史】")
            print(f"  总轮数: {result['total_rounds']}")
            iteration_history = result.get('iteration_history', [])
            if iteration_history:
                for i, round_info in enumerate(iteration_history, 1):
                    print(f"  轮次 {i}: 发现 {round_info.get('total_patterns', 0)} 个模式, "
                          f"{round_info.get('total_rules', 0)} 个规则")
        
        # Stage3: 显示排序优化状态
        if result.get('re_ranked', False):
            print("\n【排序优化】")
            print("  ✓ 已根据您的历史偏好对模式进行个性化排序")
            print("  ✓ 相似度越高的模式越符合您的兴趣")
        
        # 显示发现的模式
        patterns = result['patterns']
        re_ranked = result.get('re_ranked', False)
        similarity_scores = result.get('similarity_scores', {})
        
        if re_ranked:
            print(f"\n【发现的模式】 (共 {result['pattern_count']} 个) [已根据您的历史偏好优化排序]")
        else:
            print(f"\n【发现的模式】 (共 {result['pattern_count']} 个)")
        
        if patterns:
            for i, pattern in enumerate(patterns[:10], 1):  # 只显示前10个
                pattern_str = "{" + ", ".join(pattern['pattern']) + "}"
                
                # 构建显示信息
                info_parts = [
                    f"参与率: {pattern['participation_index']:.3f}",
                    f"置信度: {pattern['confidence']:.3f}",
                    f"大小: {pattern['size']}",
                    f"实例数: {pattern['instance_count']}"
                ]
                
                # Stage3: 如果有相似度分数，显示相似度
                if re_ranked and similarity_scores:
                    pattern_key = ','.join(sorted(pattern['pattern']))
                    if pattern_key in similarity_scores:
                        similarity = similarity_scores[pattern_key]
                        info_parts.append(f"相似度: {similarity:.3f} ⭐")
                
                print(f"  {i}. {pattern_str}")
                print(f"     {', '.join(info_parts)}")
        else:
            print("  未发现满足条件的模式")
        
        # 显示关联规则
        rules = result.get('rules', [])
        if rules:
            print(f"\n【关联规则】 (共 {result.get('rule_count', 0)} 个)")
            for i, rule in enumerate(rules[:10], 1):  # 只显示前10个
                antecedent_str = "{" + ", ".join(rule['antecedent']) + "}"
                consequent_str = "{" + ", ".join(rule['consequent']) + "}"
                print(f"  {i}. {antecedent_str} → {consequent_str}")
                print(f"     置信度: {rule['confidence']:.3f}, "
                      f"支持度: {rule['support']:.4f}, "
                      f"提升度: {rule['lift']:.3f}")
        else:
            print(f"\n【关联规则】 (共 0 个)")
            print("  未生成满足条件的关联规则")
        
        # 显示解释
        print("\n【解释】")
        print(f"  {result['explanation']}")
        print("=" * 60 + "\n")
    
    def collect_feedback(self, patterns: list, input_func=None) -> Dict[str, Any]:
        """Collect like/dislike feedback on displayed patterns from the console.

        The user answers two prompts with 1-based pattern numbers separated by
        commas (ASCII ``,`` or full-width ``，``); an empty answer skips that
        category. Each number is handled on its own: blank tokens (e.g. from a
        trailing comma) are ignored, non-integer tokens are logged and skipped,
        out-of-range numbers are skipped, and a repeated number is counted once.

        Args:
            patterns: List of patterns shown to the user, in display order
            input_func: Callable taking a prompt string and returning the raw
                answer line. ``None`` (default) uses the builtin ``input``,
                looked up at call time.

        Returns:
            Dictionary ``{"like": [...], "dislike": [...]}`` holding the selected
            pattern objects in the order the user listed them
        """
        read_line = input if input_func is None else input_func

        print("\n" + "=" * 60)
        print("【用户反馈采集】")
        print("=" * 60)
        print("请选择感兴趣的模式编号（逗号分隔，如 1,3,5），直接回车跳过：")
        like_input = read_line("like> ").strip()

        print("请选择不感兴趣的模式编号（逗号分隔），直接回车跳过：")
        dislike_input = read_line("dislike> ").strip()

        def parse_indices(text: str) -> list:
            """Return the patterns selected by comma-separated 1-based numbers."""
            selected = []
            seen = set()
            # The prompts are Chinese, so also accept the full-width comma '，'.
            for token in text.replace("，", ",").split(","):
                token = token.strip()
                if not token:
                    continue  # e.g. the empty tail of "1,3,"
                try:
                    idx = int(token) - 1
                except ValueError:
                    # Skip only the bad token instead of discarding the whole answer.
                    logger.warning(f"Invalid input format: {token}")
                    continue
                if 0 <= idx < len(patterns) and idx not in seen:
                    seen.add(idx)
                    selected.append(patterns[idx])
            return selected

        feedback = {
            "like": parse_indices(like_input),
            "dislike": parse_indices(dislike_input)
        }

        like_count = len(feedback["like"])
        dislike_count = len(feedback["dislike"])
        print(f"\n已记录反馈: {like_count} 个喜欢, {dislike_count} 个不喜欢")

        return feedback
    
    def _save_training_history(self, feedback_count: int, epochs: int, history: Dict[str, Any]):
        """保存训练历史到文件。
        
        Args:
            feedback_count: 触发训练时的反馈总数
            epochs: 训练轮数
            history: 训练历史（包含loss列表）
        """
        import json
        from datetime import datetime
        
        history_path = "logs/training_history.json"
        os.makedirs(os.path.dirname(history_path), exist_ok=True)
        
        # 加载现有历史
        if os.path.exists(history_path):
            try:
                with open(history_path, 'r', encoding='utf-8') as f:
                    history_data = json.load(f)
            except:
                history_data = {"trainings": []}
        else:
            history_data = {"trainings": []}
        
        # 添加新训练记录
        training_record = {
            "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "feedback_count": feedback_count,
            "epochs": epochs,
            "final_loss": history['loss'][-1] if history['loss'] else None,
            "initial_loss": history['loss'][0] if history['loss'] else None,
            "loss_history": history['loss']
        }
        
        history_data["trainings"].append(training_record)
        
        # 保存
        with open(history_path, 'w', encoding='utf-8') as f:
            json.dump(history_data, f, indent=2, ensure_ascii=False)
    
    def start_iteration(self, query: str, iteration_rounds: int) -> Dict[str, Any]:
        """Start an interactive iteration session (Stage4).

        Extracts mining parameters from ``query`` via the LLM, builds an
        ``IterationState`` whose user vector comes from
        ``iteration_manager.init_user_vector``, runs the first round and only
        then registers the session under a fresh UUID.

        Args:
            query: User query
            iteration_rounds: Requested number of iteration rounds; values above
                ``self.stage4_max_rounds`` are clamped to that cap

        Returns:
            Dictionary containing the session_id, the first round's patterns,
            rules and similarity scores, the parsed params, and Stage0 intent
            information (``intent_used`` / ``intent_data``)

        Raises:
            ValueError: If Stage4 is not enabled or no iteration manager exists
        """
        if not self.stage4_enabled or self.iteration_manager is None:
            raise ValueError("Stage4 is not enabled")

        iteration_rounds = min(iteration_rounds, self.stage4_max_rounds)

        # Extract mining parameters from the natural-language query
        preference_prompt = PromptTemplate.PREFERENCE_EXTRACTION_PROMPT.format(
            user_input=query
        )
        system_prompt = PromptTemplate.SYSTEM_PROMPT
        llm_response = self.llm_client.generate(preference_prompt, system=system_prompt)
        params = self.preference_parser.parse_preference(llm_response)

        # Create iteration state
        from controller.iteration_manager import IterationState
        state = IterationState(
            user_id=self.stage3_user_id,
            query=query,
            max_rounds=iteration_rounds
        )

        # Initialize user vector from LLM intent
        state.user_vector = self.iteration_manager.init_user_vector(query, self.stage3_user_id)

        # Run the first round BEFORE registering the session: if it raises, the
        # caller never receives a session_id, so a registered entry could never
        # be finalized and would leak in self.iteration_sessions.
        round_result = self.iteration_manager.run_one_iteration_step(state, params, trainer=self.trainer)

        session_id = str(uuid.uuid4())
        self.iteration_sessions[session_id] = {
            'state': state,
            'params': params,
            'query': query
        }

        # Load intent data if available (Stage0)
        intent_used = False
        intent_data = None
        if state.user_vector is not None:
            loaded_intent = self.memory.load_intent(self.stage3_user_id)
            logger.info(f"Stage4 start_iteration: Loaded intent data: {loaded_intent is not None}")
            if loaded_intent:
                intent_used = True
                logger.info(f"Stage4 start_iteration: Intent keys: {loaded_intent.keys()}")
                # Guard against a missing or null 'intent' entry
                intent_fields = loaded_intent.get('intent') or {}
                u_llm = loaded_intent.get('u_llm', [])
                # Flatten the structure for the frontend; convert array-likes to lists
                intent_data = {
                    'u_llm': u_llm.tolist() if hasattr(u_llm, 'tolist') else u_llm,
                    'business': intent_fields.get('business'),
                    'pattern_preference': intent_fields.get('pattern_preference', []),
                    'importance': intent_fields.get('importance', {}),
                    'risk': intent_fields.get('risk', {}),
                    'created_at': loaded_intent.get('created_at')
                }
                logger.info(f"Stage4 start_iteration: Flattened intent_data - business: {intent_data.get('business')}, patterns: {len(intent_data.get('pattern_preference', []))}")
        else:
            logger.info("Stage4 start_iteration: state.user_vector is None, skipping intent data loading")

        result = {
            'session_id': session_id,
            'round': round_result['round'],
            'total_rounds': iteration_rounds,
            'patterns': round_result['patterns'],
            'rules': round_result['rules'],
            'similarity_scores': round_result.get('similarity_scores', {}),
            'is_final': round_result['is_final'],
            'params': params,
            'intent_used': intent_used,  # Stage0: Whether intent understanding was used
            'intent_data': intent_data  # Stage0: Intent understanding data
        }
        logger.info(f"Stage4 start_iteration: Returning result with intent_used={intent_used}, intent_data keys={list(intent_data.keys()) if intent_data else None}")
        return result
    
    def next_iteration(self, session_id: str, feedback: Dict[str, List[str]],
                       training_epochs: int = 5) -> Dict[str, Any]:
        """Process feedback, train the model, and return the next round (Stage4).

        Args:
            session_id: Iteration session ID
            feedback: Dictionary with 'positive' and 'negative' pattern lists
            training_epochs: Epochs passed to ``iteration_manager.update_and_train``
                for this round (default 5 keeps per-round training quick)

        Returns:
            Dictionary with the round number, total rounds, patterns, rules,
            similarity scores, the training result, ``is_final`` and Stage0
            intent information. If the session state is already complete after
            training, the final ranking is returned with ``is_final=True``.

        Raises:
            ValueError: If ``session_id`` is not an active session
        """
        if session_id not in self.iteration_sessions:
            raise ValueError(f"Invalid session_id: {session_id}")

        session = self.iteration_sessions[session_id]
        state = session['state']
        params = session['params']

        # Update user vector and train model
        training_result = self.iteration_manager.update_and_train(
            state=state,
            feedback=feedback,
            trainer=self.trainer,
            training_epochs=training_epochs
        )

        # trainer.model is updated during training, so re-sync preference_model to it
        if self.stage3_use_contrastive and self.trainer is not None and self.trainer.model is not None:
            self.preference_model = self.trainer.model
            logger.debug("Updated preference_model reference after training")

        # Load intent data if available (Stage0)
        intent_used = False
        intent_data = None
        if state.user_vector is not None:
            loaded_intent = self.memory.load_intent(self.stage3_user_id)
            if loaded_intent:
                intent_used = True
                # Guard against a missing or null 'intent' entry
                intent_fields = loaded_intent.get('intent') or {}
                u_llm = loaded_intent.get('u_llm', [])
                # Flatten the structure for the frontend; convert array-likes to lists
                intent_data = {
                    'u_llm': u_llm.tolist() if hasattr(u_llm, 'tolist') else u_llm,
                    'business': intent_fields.get('business'),
                    'pattern_preference': intent_fields.get('pattern_preference', []),
                    'importance': intent_fields.get('importance', {}),
                    'risk': intent_fields.get('risk', {}),
                    'created_at': loaded_intent.get('created_at')
                }

        if state.is_complete():
            # All rounds consumed: return the final ranking instead of a new round
            result = self.iteration_manager._final_rank(state, params)
            round_number = state.current_round
            is_final = True
        else:
            result = self.iteration_manager.run_one_iteration_step(
                state, params, trainer=self.trainer
            )
            round_number = result['round']
            is_final = result['is_final']

        return {
            'round': round_number,
            'total_rounds': state.max_rounds,
            'patterns': result['patterns'],
            'rules': result['rules'],
            'similarity_scores': result.get('similarity_scores', {}),
            'training_result': training_result,
            'is_final': is_final,
            'intent_used': intent_used,  # Stage0: Whether intent understanding was used
            'intent_data': intent_data  # Stage0: Intent understanding data
        }
    
    def finalize_iteration(self, session_id: str) -> Dict[str, Any]:
        """Produce the final ranked output of a Stage4 session and close it.

        Runs the iteration manager's final ranking, asks for an explanation of
        the resulting patterns/rules, summarizes each recorded round, attaches
        Stage0 intent data when present, and finally removes the session.

        Args:
            session_id: Iteration session ID

        Returns:
            Dictionary with the final patterns, rules, similarity scores,
            explanation, per-round history, params and Stage0 intent information

        Raises:
            ValueError: If ``session_id`` is not an active session
        """
        if session_id not in self.iteration_sessions:
            raise ValueError(f"Invalid session_id: {session_id}")

        session_entry = self.iteration_sessions[session_id]
        state, params, query = (session_entry[key] for key in ('state', 'params', 'query'))

        ranked = self.iteration_manager._final_rank(state, params)
        explanation = self._generate_explanation(params, ranked['patterns'], ranked['rules'])

        # One summary row per recorded round; missing counters default to zero.
        iteration_history = [
            {
                'round': entry['round'],
                'total_patterns': entry.get('total_patterns', 0),
                'total_rules': entry.get('total_rules', 0),
                'user_vector_norm': entry.get('user_vector_norm', 0.0),
            }
            for entry in state.history
        ]

        # Stage0: intent data is only looked up when a user vector exists.
        stored_intent = None
        if state.user_vector is not None:
            stored_intent = self.memory.load_intent(self.stage3_user_id)

        intent_data = None
        if stored_intent:
            nested = stored_intent.get('intent', {})
            vector = stored_intent.get('u_llm', [])
            # Flatten nested intent fields for the frontend; array-likes become lists.
            intent_data = {
                'u_llm': vector.tolist() if hasattr(vector, 'tolist') else vector,
                'business': nested.get('business'),
                'pattern_preference': nested.get('pattern_preference', []),
                'importance': nested.get('importance', {}),
                'risk': nested.get('risk', {}),
                'created_at': stored_intent.get('created_at'),
            }
        intent_used = intent_data is not None

        # The session is finished; drop it.
        del self.iteration_sessions[session_id]

        return {
            'patterns': ranked['patterns'],
            'rules': ranked['rules'],
            'similarity_scores': ranked.get('similarity_scores', {}),
            'explanation': explanation,
            'iteration_history': iteration_history,
            'params': params,
            'intent_used': intent_used,  # Stage0: Whether intent understanding was used
            'intent_data': intent_data,  # Stage0: Intent understanding data
        }

