"""Dataset for contrastive learning with triplet structure."""

import numpy as np
from typing import List, Tuple, Dict, Any
import logging

from memory.store import MemoryStore
from learning.embedder import PatternEmbedder

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class PreferenceDataset:
    """Dataset for triplet-based contrastive learning.

    Each sample is (anchor, positive, negative):
    - anchor: User preference vector (see ``build_user_vector``)
    - positive: Liked pattern vector
    - negative: Disliked pattern vector

    The samples enumerate every (positive, negative) pair, with positives
    varying slowest. Sample ``idx`` pairs
    ``positive_vectors[idx // len(negative_vectors)]`` with
    ``negative_vectors[idx % len(negative_vectors)]``, which is the same
    order ``get_triplets`` produces.
    """

    def __init__(self, memory_store: "MemoryStore", embedder: "PatternEmbedder", user_id: str = "user_001"):
        """Initialize preference dataset.

        Args:
            memory_store: Memory store instance. Its ``load_user_memory()``
                result is looked up by ``user_id``.
            embedder: Pattern embedder for encoding patterns. Its model
                supplies the vector dimension when the user has no feedback.
            user_id: User identifier.
        """
        self.memory_store = memory_store
        self.embedder = embedder
        self.user_id = user_id

        # Load the user memory and pick out this user's profile (None if absent).
        self.user_memory = memory_store.load_user_memory()
        self.user_profile = self.user_memory.get(user_id)

        if self.user_profile is None:
            logger.warning(f"User {user_id} not found in memory")
            self.positive_vectors = []
            self.negative_vectors = []
        else:
            # Load the liked ("positive") and disliked ("negative") vectors as arrays.
            self.positive_vectors = [np.array(v) for v in self.user_profile.get("positive", [])]
            self.negative_vectors = [np.array(v) for v in self.user_profile.get("negative", [])]

        logger.info(f"Dataset initialized: {len(self.positive_vectors)} positive, "
                    f"{len(self.negative_vectors)} negative vectors")

    def build_user_vector(self) -> np.ndarray:
        """Build user anchor vector.

        Returns:
            User preference vector ``mean_pos - mean_neg``. If one side has
            no vectors, that side contributes zeros shaped like the other
            side's mean. If there is no feedback at all, returns a zero
            vector whose length is
            ``self.embedder.model.get_sentence_embedding_dimension()``.
        """
        if len(self.positive_vectors) == 0 and len(self.negative_vectors) == 0:
            # No feedback at all: zero vector of the embedder's dimension.
            dim = self.embedder.model.get_sentence_embedding_dimension()
            return np.zeros(dim)

        mean_pos = np.mean(self.positive_vectors, axis=0) if len(self.positive_vectors) > 0 else None
        mean_neg = np.mean(self.negative_vectors, axis=0) if len(self.negative_vectors) > 0 else None

        # At least one side is non-empty here. The missing side's shape must
        # come from the side that actually has vectors; the empty list has no
        # element to inspect.
        if mean_pos is None:
            mean_pos = np.zeros(np.shape(mean_neg))
        if mean_neg is None:
            mean_neg = np.zeros(np.shape(mean_pos))

        # User vector = mean of liked vectors - mean of disliked vectors.
        return mean_pos - mean_neg

    def get_triplets(self) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Generate triplet samples.

        Returns:
            List of (anchor, positive, negative) triplets: one per
            (positive, negative) pair, positives in the outer loop. All
            triplets share the same anchor array. The list is empty (and a
            warning is logged) if either side has no vectors.
        """
        if len(self.positive_vectors) == 0 or len(self.negative_vectors) == 0:
            logger.warning("Insufficient data for triplet generation")
            return []

        anchor = self.build_user_vector()

        # Every (positive, negative) combination, positives varying slowest.
        triplets = [
            (anchor, pos_vec, neg_vec)
            for pos_vec in self.positive_vectors
            for neg_vec in self.negative_vectors
        ]

        logger.info(f"Generated {len(triplets)} triplets")
        return triplets

    def __len__(self) -> int:
        """Get dataset size: ``len(positive_vectors) * len(negative_vectors)``."""
        # Computed arithmetically instead of materializing every triplet.
        return len(self.positive_vectors) * len(self.negative_vectors)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Get a triplet sample.

        Args:
            idx: Sample index. Negative values count from the end, as with a list.

        Returns:
            (anchor, positive, negative) triplet at position ``idx`` in
            ``get_triplets`` order.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        size = len(self)
        position = idx + size if idx < 0 else idx
        if not 0 <= position < size:
            raise IndexError(f"Index {idx} out of range for dataset of size {size}")

        # Decode the flat index: positives vary slowest, matching get_triplets().
        pos_idx, neg_idx = divmod(position, len(self.negative_vectors))
        return (
            self.build_user_vector(),
            self.positive_vectors[pos_idx],
            self.negative_vectors[neg_idx],
        )

