"""Service layer for encapsulating business logic."""

import logging
import os
from typing import Dict, Any, Optional, List
import numpy as np

from controller.manager import PipelineManager
from memory.store import MemoryStore
from learning.embedder import PatternEmbedder

# Module-level logger named after this module; used by WebService methods.
logger = logging.getLogger(__name__)


class WebService:
    """Service layer for the web API.

    Wraps a ``PipelineManager`` and adapts its results for HTTP responses:
    numpy values are converted to native Python types and some keys are
    renamed to match the API schema.

    Attributes:
        manager: Pipeline manager that executes queries and iterations and
            holds the Stage3 trainer / preference model.
        memory: ``MemoryStore`` used to persist and read user feedback.
        embedder: ``PatternEmbedder`` used to encode patterns into vectors.
    """

    def __init__(self, manager: "PipelineManager"):
        """Initialize the web service.

        Args:
            manager: PipelineManager instance that requests delegate to.
        """
        self.manager = manager
        self.memory = MemoryStore()
        self.embedder = PatternEmbedder()

    def run_query(self, query: str, iteration_rounds: Optional[int] = None) -> Dict[str, Any]:
        """Execute a query and return JSON-friendly results.

        Args:
            query: Natural language query string.
            iteration_rounds: Optional number of iteration rounds (Stage4),
                forwarded to ``PipelineManager.process_query``.

        Returns:
            The manager's result with numpy values converted to native Python
            types and ``extracted_parameters`` renamed to ``params``.

        Raises:
            Exception: Whatever the manager raises; it is logged and re-raised.
        """
        try:
            result = self.manager.process_query(query, iteration_rounds=iteration_rounds)
            # Convert numpy arrays/scalars so the result is JSON-serializable.
            cleaned_result = self._clean_result(result)
            # The API schema names this field ``params``.
            if 'extracted_parameters' in cleaned_result:
                cleaned_result['params'] = cleaned_result.pop('extracted_parameters')
            return cleaned_result
        except Exception as e:
            logger.exception("Error processing query: %s", e)
            raise

    def add_feedback(self, pattern_id: int, pattern: list, feedback: str) -> Dict[str, Any]:
        """Record user feedback for a pattern.

        Args:
            pattern_id: Pattern identifier (accepted but not used here).
            pattern: Pattern content (list of feature types) to encode.
            feedback: ``"positive"`` or ``"negative"``.

        Returns:
            Dictionary with ``status``, ``message`` and the user's updated
            ``positive_count`` / ``negative_count``.

        Raises:
            ValueError: If ``feedback`` is neither "positive" nor "negative".
        """
        try:
            user_id = self.manager.stage3_user_id

            # Validate first so invalid input is rejected before encoding.
            if feedback not in ("positive", "negative"):
                raise ValueError(f"Invalid feedback type: {feedback}")

            pattern_vec = self.embedder.encode_pattern(pattern)

            # The memory store only needs the user id and the pattern vector.
            if feedback == "positive":
                self.memory.add_positive(user_id, pattern_vec)
            else:
                self.memory.add_negative(user_id, pattern_vec)

            positive_count, negative_count = self._feedback_counts(
                self.memory.get_user_profile(user_id)
            )
            return {
                "status": "ok",
                "message": f"Feedback saved: {feedback}",
                "positive_count": positive_count,
                "negative_count": negative_count
            }
        except Exception as e:
            logger.exception("Error adding feedback: %s", e)
            raise

    @staticmethod
    def _feedback_counts(user_profile: Optional[Dict[str, Any]]) -> tuple:
        """Return ``(positive_count, negative_count)`` for a user profile.

        A missing or empty profile (``None`` / ``{}``) counts as no feedback.
        """
        profile = user_profile or {}
        return len(profile.get('positive', [])), len(profile.get('negative', []))

    @staticmethod
    def _training_error(message: str) -> Dict[str, Any]:
        """Build the error-shaped response returned by ``run_training``."""
        return {
            "status": "error",
            "message": message,
            "loss_history": [],
            "epochs": 0
        }

    def run_training(self, epochs: int = 10, learning_rate: float = 0.001,
                     batch_size: int = 32, margin: float = 0.3) -> Dict[str, Any]:
        """Run preference model training.

        Args:
            epochs: Number of training epochs.
            learning_rate: Learning rate for a fresh Adam optimizer.
            batch_size: Batch size.
            margin: Margin for triplet loss. NOTE(review): accepted but not
                forwarded to the trainer by this method — confirm where it
                should be applied.

        Returns:
            Dictionary with ``status``, ``message``, ``loss_history`` and
            ``epochs``. Failures (including unexpected exceptions) are
            reported with ``status == "error"`` instead of being raised.
        """
        try:
            manager = self.manager
            if not manager.stage3_enabled:
                return self._training_error("Stage3 is not enabled")
            if not manager.stage3_use_contrastive:
                return self._training_error("Contrastive learning is not enabled")
            trainer = manager.trainer
            if trainer is None:
                return self._training_error("Trainer is not initialized")

            # Replace the optimizer so the requested learning rate takes effect.
            import torch.optim as optim
            trainer.optimizer = optim.Adam(trainer.model.parameters(), lr=learning_rate)

            history = trainer.train(
                epochs=epochs,
                batch_size=batch_size,
                save_path=manager.stage3_model_path
            )

            # Record this run together with the amount of feedback it saw.
            positive_count, negative_count = self._feedback_counts(
                self.memory.get_user_profile(manager.stage3_user_id)
            )
            manager._save_training_history(positive_count + negative_count, epochs, history)

            # Pick up the freshly saved weights (best-effort).
            self._reload_preference_model()

            return {
                "status": "ok",
                "message": "Training completed successfully",
                "loss_history": (history or {}).get("loss", []),
                "epochs": epochs
            }
        except Exception as e:
            logger.exception("Error during training: %s", e)
            return self._training_error(str(e))

    def _reload_preference_model(self) -> None:
        """Reload ``manager.preference_model`` from the saved training checkpoint.

        Best-effort: does nothing when there is no checkpoint file, and logs a
        warning (keeping the previously loaded model) if rebuilding fails.
        """
        model_path = getattr(self.manager, 'stage3_model_path', None)
        if not model_path or not os.path.exists(model_path):
            return
        try:
            import torch
            import yaml
            from embedding.encoder import PreferenceEncoder

            # hidden_dim comes from the project config; input_dim from the embedder.
            config_path = "config/config.yaml"
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f) or {}
            input_dim = self.embedder.model.get_sentence_embedding_dimension()
            hidden_dim = (config.get('stage3') or {}).get('hidden_dim', 256)

            # Build and load into a local first so a failed load cannot leave
            # the manager holding a freshly initialized (untrained) model.
            model = PreferenceEncoder(input_dim=input_dim, hidden_dim=hidden_dim)
            # NOTE(review): torch.load unpickles the file; presumably safe only
            # because the checkpoint is written by our own trainer.
            checkpoint = torch.load(model_path, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            self.manager.preference_model = model
            logger.info("Preference model reloaded after training")
        except Exception as e:
            logger.warning("Failed to reload preference model: %s", e)

    def get_training_history(self) -> Dict[str, Any]:
        """Read the training history log.

        Returns:
            The parsed contents of ``logs/training_history.json``, or
            ``{"trainings": []}`` if the file is missing or unreadable.
        """
        import json

        history_path = "logs/training_history.json"
        try:
            with open(history_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            return {"trainings": []}
        except (OSError, ValueError) as e:
            # ValueError covers malformed JSON and undecodable bytes.
            logger.exception("Error reading training history: %s", e)
            return {"trainings": []}

    def _clean_value(self, value: Any) -> Any:
        """Return *value* with numpy arrays/scalars converted to native Python."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            # numpy scalars (e.g. np.float32, np.int64) are not JSON-serializable.
            return value.item()
        if isinstance(value, dict):
            return self._clean_dict(value)
        if isinstance(value, list):
            return self._clean_list(value)
        return value

    def _clean_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Clean a result dictionary, converting numpy values to Python types.

        Args:
            result: Raw result dictionary.

        Returns:
            A new dictionary; nested dicts and lists are cleaned recursively.
        """
        return self._clean_dict(result)

    def _clean_dict(self, d: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively clean a dictionary's values."""
        return {k: self._clean_value(v) for k, v in d.items()}

    def _clean_list(self, l: list) -> list:
        """Recursively clean a list's items."""
        return [self._clean_value(item) for item in l]

    def start_iteration(self, query: str, iteration_rounds: int) -> Dict[str, Any]:
        """Start an interactive iteration session (Stage4).

        Args:
            query: User query.
            iteration_rounds: Number of iteration rounds.

        Returns:
            The manager's result (session id and first round), cleaned of
            numpy values.

        Raises:
            Exception: Whatever the manager raises; it is logged and re-raised.
        """
        try:
            result = self.manager.start_iteration(query, iteration_rounds)
            return self._clean_result(result)
        except Exception as e:
            logger.exception("Error starting iteration: %s", e)
            raise

    def next_iteration(self, session_id: str, feedback: Dict[str, List[str]]) -> Dict[str, Any]:
        """Submit feedback and get the next iteration round (Stage4).

        Args:
            session_id: Iteration session ID.
            feedback: Dictionary with 'positive' and 'negative' pattern lists.

        Returns:
            The manager's next-round result, cleaned of numpy values.

        Raises:
            Exception: Whatever the manager raises; it is logged and re-raised.
        """
        try:
            result = self.manager.next_iteration(session_id, feedback)
            return self._clean_result(result)
        except Exception as e:
            logger.exception("Error in next iteration: %s", e)
            raise

    def finalize_iteration(self, session_id: str) -> Dict[str, Any]:
        """Get final results after all iterations (Stage4).

        Args:
            session_id: Iteration session ID.

        Returns:
            The manager's final result (including explanation), cleaned of
            numpy values.

        Raises:
            Exception: Whatever the manager raises; it is logged and re-raised.
        """
        try:
            result = self.manager.finalize_iteration(session_id)
            return self._clean_result(result)
        except Exception as e:
            logger.exception("Error finalizing iteration: %s", e)
            raise

