"""Trainer for contrastive learning with triplet loss."""

import os
import numpy as np
from typing import Optional, Dict, Any
import logging

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader, TensorDataset
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    logging.warning("PyTorch not available, PreferenceTrainer will be disabled")

from embedding.encoder import PreferenceEncoder
from learning.dataset import PreferenceDataset
from memory.store import MemoryStore
from learning.embedder import PatternEmbedder

logger = logging.getLogger(__name__)


class TripletLoss(nn.Module if TORCH_AVAILABLE else object):
    """Euclidean-distance triplet loss for contrastive learning.

    For every triplet in the batch::

        loss_i = max(0, d(anchor_i, positive_i) - d(anchor_i, negative_i) + margin)

    where ``d`` is the Euclidean distance computed by
    ``torch.nn.functional.pairwise_distance`` (p=2, reduced over the last
    dimension). The returned value is the mean of ``loss_i`` over the batch.
    """

    def __init__(self, margin: float = 0.3):
        """Initialize the triplet loss.

        Args:
            margin: Minimum gap the loss tries to enforce between the
                anchor-negative and anchor-positive distances.

        Raises:
            ImportError: If PyTorch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required for TripletLoss")

        super().__init__()
        self.margin = margin

    # The annotations are strings so that defining this class does not
    # evaluate ``torch.Tensor`` (a NameError) when PyTorch is unavailable;
    # otherwise the TORCH_AVAILABLE guard could never take effect.
    def forward(self, anchor: "torch.Tensor", positive: "torch.Tensor",
                negative: "torch.Tensor") -> "torch.Tensor":
        """Compute the mean triplet loss over a batch.

        Args:
            anchor: Anchor vectors, shape (batch_size, dim).
            positive: Positive vectors, same shape as ``anchor``.
            negative: Negative vectors, same shape as ``anchor``.

        Returns:
            Scalar tensor holding the batch-mean loss.
        """
        # Distances from each anchor to its positive and negative partner.
        d_pos = torch.nn.functional.pairwise_distance(anchor, positive)
        d_neg = torch.nn.functional.pairwise_distance(anchor, negative)

        # Hinge: zero loss once the negative is farther by at least `margin`.
        per_triplet = torch.clamp(d_pos - d_neg + self.margin, min=0.0)

        return per_triplet.mean()


class CosineTripletLoss(nn.Module if TORCH_AVAILABLE else object):
    """Cosine-similarity triplet loss.

    For every triplet in the batch::

        loss_i = max(0, -cos(anchor_i, positive_i) + cos(anchor_i, negative_i) + margin)

    The returned value is the mean of ``loss_i`` over the batch.
    """

    def __init__(self, margin: float = 0.3):
        """Initialize the cosine triplet loss.

        Args:
            margin: Minimum gap the loss tries to enforce between the
                anchor-positive and anchor-negative cosine similarities.

        Raises:
            ImportError: If PyTorch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required for CosineTripletLoss")

        super().__init__()
        self.margin = margin

    # The annotations are strings so that defining this class does not
    # evaluate ``torch.Tensor`` (a NameError) when PyTorch is unavailable.
    def forward(self, anchor: "torch.Tensor", positive: "torch.Tensor",
                negative: "torch.Tensor") -> "torch.Tensor":
        """Compute the mean cosine triplet loss over a batch.

        Args:
            anchor: Anchor vectors, shape (batch_size, dim).
            positive: Positive vectors, same shape as ``anchor``.
            negative: Negative vectors, same shape as ``anchor``.

        Returns:
            Scalar tensor holding the batch-mean loss.
        """
        # Similarity is taken over the last (feature) dimension. That matches
        # pairwise_distance in TripletLoss and equals dim=1 for 2-D input.
        cos_pos = torch.nn.functional.cosine_similarity(anchor, positive, dim=-1)
        cos_neg = torch.nn.functional.cosine_similarity(anchor, negative, dim=-1)

        # Hinge: zero loss once the positive is more similar by at least `margin`.
        per_triplet = torch.clamp(cos_neg - cos_pos + self.margin, min=0.0)

        return per_triplet.mean()


class PreferenceTrainer:
    """Train a ``PreferenceEncoder`` with triplet-based contrastive learning.

    Training triplets come from
    ``PreferenceDataset(memory_store, embedder, user_id).get_triplets()``.
    All three members of each triplet pass through the same encoder, and the
    projections are scored with ``CosineTripletLoss`` or ``TripletLoss``.
    """

    def __init__(self,
                 memory_store: MemoryStore,
                 embedder: PatternEmbedder,
                 model: Optional[PreferenceEncoder] = None,
                 input_dim: int = 384,
                 hidden_dim: int = 256,
                 margin: float = 0.3,
                 learning_rate: float = 0.001,
                 use_cosine_loss: bool = True,
                 user_id: str = "user_001"):
        """Initialize the preference trainer.

        Args:
            memory_store: Memory store the training triplets are drawn from.
            embedder: Pattern embedder passed to ``PreferenceDataset``.
            model: Optional (possibly pre-trained) encoder. When given,
                ``input_dim`` and ``hidden_dim`` are ignored.
            input_dim: Input embedding dimension for a newly built encoder.
            hidden_dim: Hidden layer dimension for a newly built encoder.
            margin: Margin for the triplet loss.
            learning_rate: Adam learning rate.
            use_cosine_loss: Use ``CosineTripletLoss`` if True, otherwise the
                Euclidean ``TripletLoss``.
            user_id: User whose data ``PreferenceDataset`` should use.

        Raises:
            ImportError: If PyTorch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required for PreferenceTrainer")

        self.memory_store = memory_store
        self.embedder = embedder
        self.user_id = user_id
        self.margin = margin
        self.learning_rate = learning_rate

        # Build a fresh encoder unless the caller supplied one.
        if model is None:
            self.model = PreferenceEncoder(input_dim=input_dim, hidden_dim=hidden_dim)
        else:
            self.model = model

        # Loss function.
        if use_cosine_loss:
            self.criterion = CosineTripletLoss(margin=margin)
        else:
            self.criterion = TripletLoss(margin=margin)

        # Optimizer over all encoder parameters.
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

        logger.info("PreferenceTrainer initialized with margin=%s, lr=%s", margin, learning_rate)

    def train(self, epochs: int = 10, batch_size: int = 32,
              save_path: Optional[str] = None) -> Dict[str, Any]:
        """Train the preference encoder.

        Args:
            epochs: Number of passes over the triplet set.
            batch_size: Triplets per optimization step.
            save_path: If truthy, the trained model is saved here after
                training (see ``save_model``).

        Returns:
            ``{"loss": [mean batch loss per epoch], "epochs": epochs}``, or
            ``{"loss": [], "epochs": 0}`` when no triplets are available.
        """
        dataset = PreferenceDataset(self.memory_store, self.embedder, self.user_id)
        triplets = dataset.get_triplets()

        if len(triplets) == 0:
            logger.warning("No triplets available for training")
            return {"loss": [], "epochs": 0}

        dataloader = self._build_dataloader(triplets, batch_size)

        history: Dict[str, Any] = {"loss": []}

        logger.info("Starting training with %s triplets, %s epochs", len(triplets), epochs)

        # NOTE(review): the model is left in training mode when this returns;
        # callers running inference may want to call self.model.eval().
        self.model.train()

        for epoch in range(epochs):
            avg_loss = self._run_epoch(dataloader)
            history["loss"].append(avg_loss)

            # Report progress to both the logger and stdout.
            progress_msg = f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}"
            logger.info(progress_msg)
            print(f"  {progress_msg}")

            # Show the direction and size of the change from the previous epoch.
            if epoch > 0:
                loss_change = history["loss"][-2] - avg_loss
                trend = "↓" if loss_change > 0 else "↑" if loss_change < 0 else "→"
                print(f"    Loss变化: {trend} {abs(loss_change):.4f}")

        if save_path:
            self.save_model(save_path)

        history["epochs"] = epochs
        return history

    @staticmethod
    def _to_float_tensor(vectors) -> "torch.Tensor":
        """Stack a sequence of equal-length vectors into one float32 tensor."""
        # Stacking with NumPy first avoids torch's slow per-element conversion
        # (and its UserWarning) when given a Python list of ndarrays.
        return torch.from_numpy(np.asarray(vectors, dtype=np.float32))

    def _build_dataloader(self, triplets, batch_size: int) -> "DataLoader":
        """Turn (anchor, positive, negative) triplets into a shuffled DataLoader."""
        anchors = self._to_float_tensor([t[0] for t in triplets])
        positives = self._to_float_tensor([t[1] for t in triplets])
        negatives = self._to_float_tensor([t[2] for t in triplets])

        tensor_dataset = TensorDataset(anchors, positives, negatives)
        return DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)

    def _run_epoch(self, dataloader: "DataLoader") -> float:
        """Run one optimization pass; return the mean of the per-batch losses."""
        epoch_loss = 0.0
        n_batches = 0

        for batch_anchors, batch_positives, batch_negatives in dataloader:
            # Forward pass: all three members share the same encoder weights.
            anchor_proj = self.model(batch_anchors)
            positive_proj = self.model(batch_positives)
            negative_proj = self.model(batch_negatives)

            loss = self.criterion(anchor_proj, positive_proj, negative_proj)

            # Backward pass and parameter update.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        # Average of batch means. A smaller final batch is weighted like the
        # others, so this is not exactly the per-triplet mean.
        return epoch_loss / n_batches if n_batches > 0 else 0.0

    def save_model(self, path: str):
        """Save model and optimizer state plus encoder dimensions.

        Args:
            path: Destination file. Missing parent directories are created.
        """
        directory = os.path.dirname(path)
        # dirname() is "" for a bare filename such as "model.pt", and
        # os.makedirs("") would raise FileNotFoundError.
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Assumes the encoder exposes input_dim / hidden_dim / output_dim
        # attributes (PreferenceEncoder is defined elsewhere).
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'input_dim': self.model.input_dim,
            'hidden_dim': self.model.hidden_dim,
            'output_dim': self.model.output_dim,
        }, path)
        logger.info("Model saved to %s", path)

    def load_model(self, path: str, map_location: Any = "cpu"):
        """Load model and optimizer state written by ``save_model``.

        Args:
            path: Checkpoint file to read.
            map_location: Forwarded to ``torch.load``. The default "cpu" lets
                GPU-saved checkpoints load on CPU-only hosts, and
                ``load_state_dict`` then copies the values onto the devices
                of the current parameters. Pass ``None`` to use torch's
                default device restoration.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=map_location)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logger.info("Model loaded from %s", path)

