"""Pattern encoder with learnable preference projection."""

import numpy as np
from typing import List, Union
import logging

logger = logging.getLogger(__name__)

# PyTorch is optional: PreferenceEncoder needs it, PatternEncoder does not.
try:
    import torch
    import torch.nn as nn
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    # Log through the module logger rather than the module-level
    # logging.warning(): the latter calls logging.basicConfig() when the root
    # logger has no handlers, which would reconfigure the application's logging
    # as a side effect of importing this module. With no configuration, the
    # record still reaches stderr via logging's last-resort handler.
    logger.warning("PyTorch not available, PreferenceEncoder will be disabled")

from sentence_transformers import SentenceTransformer


class PatternEncoder:
    """Base pattern encoder wrapping a sentence-transformers model (frozen).

    The model is only used for inference here; this class exposes no training
    functionality.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load the sentence-transformer model and record its embedding size.

        Args:
            model_name: Name of the sentence transformer model, passed straight
                to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name)
        # Size of the vectors produced by ``encode``, as reported by the model.
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        logger.info("PatternEncoder initialized with dimension %s", self.embedding_dim)

    def encode(self, text: Union[str, List[str]]) -> np.ndarray:
        """Encode one pattern or a list of patterns into embedding vector(s).

        Args:
            text: A single pattern string (e.g. ``"Park Zoo"``) or a list of
                pattern strings.

        Returns:
            Whatever ``SentenceTransformer.encode`` returns for ``text`` with
            default arguments — presumably a 1-D vector for a single string and
            a 2-D array with one row per string for a list (TODO confirm
            against the installed sentence-transformers version).
        """
        # SentenceTransformer.encode accepts both a single str and a list of
        # str, so no type dispatch is needed here.
        return self.model.encode(text)


class PreferenceEncoder(nn.Module if TORCH_AVAILABLE else object):
    """Learnable preference encoder for projecting embeddings to preference space.

    Architecture:
        Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> output_dim)
        -> L2 normalize

    When PyTorch is not installed the class still exists (so this module can be
    imported), but instantiating it raises ``ImportError``.
    """

    def __init__(self, input_dim: int = 384, hidden_dim: int = 256, output_dim: Union[int, None] = None):
        """Build the two-layer projection network.

        Args:
            input_dim: Input embedding dimension.
            hidden_dim: Hidden layer dimension.
            output_dim: Output dimension (default: same as ``input_dim``).

        Raises:
            ImportError: If PyTorch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required for PreferenceEncoder")

        super().__init__()

        if output_dim is None:
            output_dim = input_dim

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Network layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Weight initialization
        self._initialize_weights()

        logger.info("PreferenceEncoder initialized: %s -> %s -> %s", input_dim, hidden_dim, output_dim)

    def _initialize_weights(self):
        """Xavier-uniform initialize both linear weights and zero their biases."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    # Annotations are strings so the class body can be executed (and the
    # module imported) even when ``torch`` is not installed.
    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Project ``x`` and L2-normalize the result.

        Args:
            x: Input tensor of shape ``(batch_size, input_dim)``; a single
                ``(input_dim,)`` vector also works.

        Returns:
            Tensor with the same leading shape and last dimension ``output_dim``.
            Each vector has unit L2 norm, except that vectors with norm below
            ``normalize``'s eps are divided by eps instead.
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        # Normalize along the feature (last) axis; identical to dim=1 for 2-D
        # input, and also correct for an unbatched 1-D vector.
        return torch.nn.functional.normalize(x, p=2, dim=-1)

    def encode(self, x: "Union[np.ndarray, torch.Tensor]") -> np.ndarray:
        """Encode input to preference space (for inference).

        Args:
            x: A single embedding of shape ``(input_dim,)`` or a batch of shape
                ``(batch_size, input_dim)``, as a NumPy array or tensor.

        Returns:
            Projected vectors as a NumPy array. A single-row result is squeezed
            to 1-D, so a ``(1, input_dim)`` batch also yields ``(output_dim,)``.
        """
        # Match the parameters' dtype and device: a float64 input, or a model
        # moved to GPU, would otherwise fail inside the linear layers.
        param = next(self.parameters())
        x = torch.as_tensor(x, dtype=param.dtype, device=param.device)

        if x.dim() == 1:
            x = x.unsqueeze(0)

        # Run in eval mode without autograd, then restore the caller's mode so
        # calling encode() in the middle of training does not leave the module
        # switched to eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x)
        finally:
            self.train(was_training)

        # .cpu() is required before .numpy() when the module lives on a GPU;
        # it is a no-op for CPU tensors.
        result = output.cpu().numpy()
        return result[0] if result.shape[0] == 1 else result

