"""Pattern sampling strategies for diverse recommendation."""

import numpy as np
import random
from typing import List, Dict, Any, Optional
import logging

logger = logging.getLogger(__name__)


class DiversitySampler:
    """Sampler for diverse pattern selection to ensure balanced training data.

    The class is a namespace of static methods. Randomised strategies draw from
    the module-level ``random`` functions by default; pass a seeded
    ``random.Random`` via ``rng`` to make a selection reproducible.
    """

    @staticmethod
    def sample_diverse_patterns(
        patterns: List[Dict[str, Any]],
        scores: List[float],
        top_k: int = 5,
        strategy: str = "mixed",
        *,
        rng: Optional[random.Random] = None,
    ) -> List[Dict[str, Any]]:
        """Sample diverse patterns to ensure balanced training data.

        Args:
            patterns: List of pattern dictionaries.
            scores: Score for each pattern (same length as ``patterns``).
            top_k: Number of patterns to select. Values <= 0 select nothing.
            strategy: Sampling strategy
                - "top": Select top-k highest scores (original)
                - "mixed": Mix of top, middle, and random picks
                - "stratified": Sample from high / medium / low rank strata
                - "uncertainty": Select the patterns whose scores are closest
                  to the median score
                Unknown values log a warning and fall back to "mixed".
            rng: Optional random generator used by the randomised strategies
                ("mixed", "stratified"); defaults to the module-level
                ``random`` functions.

        Returns:
            A new list of ``min(top_k, len(patterns))`` selected pattern
            dictionaries (empty when ``top_k <= 0``).

        Raises:
            ValueError: If ``patterns`` and ``scores`` differ in length.
        """
        if len(patterns) != len(scores):
            raise ValueError(f"Patterns and scores length mismatch: {len(patterns)} vs {len(scores)}")

        if top_k <= 0:
            return []

        if len(patterns) <= top_k:
            # Copy so callers mutating the result never touch their input list.
            return list(patterns)

        if strategy == "top":
            return DiversitySampler._sample_top(patterns, scores, top_k)
        if strategy == "mixed":
            return DiversitySampler._sample_mixed(patterns, scores, top_k, rng)
        if strategy == "stratified":
            return DiversitySampler._sample_stratified(patterns, scores, top_k, rng)
        if strategy == "uncertainty":
            return DiversitySampler._sample_uncertainty(patterns, scores, top_k)

        logger.warning("Unknown strategy %s, using 'mixed'", strategy)
        return DiversitySampler._sample_mixed(patterns, scores, top_k, rng)

    @staticmethod
    def _ranked_patterns(patterns: List[Dict[str, Any]], scores: List[float]) -> List[Dict[str, Any]]:
        """Return ``patterns`` ordered by descending score (ties keep input order)."""
        # sorted() is stable even with reverse=True, so equal scores keep input order.
        order = sorted(range(len(patterns)), key=lambda i: scores[i], reverse=True)
        return [patterns[i] for i in order]

    @staticmethod
    def _sample_top(patterns: List[Dict[str, Any]], scores: List[float], top_k: int) -> List[Dict[str, Any]]:
        """Original strategy: select the top-k highest scores, best first."""
        return DiversitySampler._ranked_patterns(patterns, scores)[:max(top_k, 0)]

    @staticmethod
    def _sample_mixed(
        patterns: List[Dict[str, Any]],
        scores: List[float],
        top_k: int,
        rng: Optional[random.Random] = None,
    ) -> List[Dict[str, Any]]:
        """Mixed strategy: top + middle + random.

        Target distribution of the ``k = min(top_k, len(patterns))`` picks:
        - Top ~40%: highest-ranked patterns (at least one)
        - Middle ~30%: random picks from the middle half of the ranking
          (ranks ``n // 4`` up to ``3 * n // 4``)
        - Remainder: random picks from everything not yet selected

        Quotas are clamped to ``k`` and any shortfall in the middle band is
        made up by the random band, so exactly ``k`` distinct patterns are
        returned, in shuffled order.
        """
        rand: Any = random if rng is None else rng
        n = len(patterns)
        k = min(top_k, n)
        if k <= 0:
            return []
        ranked = DiversitySampler._ranked_patterns(patterns, scores)

        # max(1, ...) guarantees one top pick; the min(...) clamps stop the
        # floors from overshooting small k (top_k=1 used to request -1 randoms).
        n_top = min(k, max(1, int(k * 0.4)))
        n_middle = min(k - n_top, max(1, int(k * 0.3)))

        chosen = list(range(n_top))
        taken = set(chosen)

        middle_candidates = [i for i in range(n // 4, 3 * n // 4) if i not in taken]
        middle_picks = rand.sample(middle_candidates, min(n_middle, len(middle_candidates)))
        chosen.extend(middle_picks)
        taken.update(middle_picks)

        # Fill up to exactly k; n >= k guarantees enough unselected indices.
        remaining = [i for i in range(n) if i not in taken]
        random_picks = rand.sample(remaining, k - len(chosen))
        chosen.extend(random_picks)

        selected = [ranked[i] for i in chosen]
        # Shuffle so output position carries no rank information.
        rand.shuffle(selected)

        logger.info(
            "Mixed sampling: %d top + %d middle + %d random",
            n_top, len(middle_picks), len(random_picks),
        )
        return selected

    @staticmethod
    def _sample_stratified(
        patterns: List[Dict[str, Any]],
        scores: List[float],
        top_k: int,
        rng: Optional[random.Random] = None,
    ) -> List[Dict[str, Any]]:
        """Stratified sampling: split the ranking into high / medium / low thirds.

        Strata are contiguous rank bands of ``n // 3`` patterns each (the low
        stratum also takes the ``n % 3`` leftovers). With
        ``k = min(top_k, len(patterns))``, each stratum is asked for
        ``max(1, k // 3)`` picks and the low stratum also for the remainder.
        Requests are clamped to the stratum size and to what is still needed.
        Any shortfall is filled at random from the unselected patterns, so
        exactly ``k`` distinct patterns are returned, in shuffled order.
        """
        rand: Any = random if rng is None else rng
        n = len(patterns)
        k = min(top_k, n)
        if k <= 0:
            return []
        ranked = DiversitySampler._ranked_patterns(patterns, scores)

        n_strata = 3
        samples_per_stratum = max(1, k // n_strata)
        # Negative when k < n_strata; the clamping below absorbs that.
        remainder = k - samples_per_stratum * n_strata
        stratum_size = n // n_strata

        chosen: List[int] = []
        for stratum in range(n_strata):
            is_last = stratum == n_strata - 1
            start = stratum * stratum_size
            end = n if is_last else start + stratum_size
            quota = samples_per_stratum + (remainder if is_last else 0)
            quota = min(quota, end - start, k - len(chosen))
            if quota > 0:
                chosen.extend(rand.sample(range(start, end), quota))

        # Strata that were too small leave a gap; back-fill from the rest.
        shortfall = k - len(chosen)
        if shortfall > 0:
            taken = set(chosen)
            leftovers = [i for i in range(n) if i not in taken]
            chosen.extend(rand.sample(leftovers, shortfall))

        selected = [ranked[i] for i in chosen]
        rand.shuffle(selected)
        logger.info("Stratified sampling: %d per stratum", samples_per_stratum)
        return selected

    @staticmethod
    def _sample_uncertainty(patterns: List[Dict[str, Any]], scores: List[float], top_k: int) -> List[Dict[str, Any]]:
        """Uncertainty-based sampling: select patterns near the decision boundary.

        The median score is used as the boundary; the ``top_k`` patterns whose
        scores are closest to it are returned, nearest first. Ties are broken
        by input order (stable sort), so the result is deterministic.
        """
        if top_k <= 0 or not patterns:
            return []
        scores_array = np.asarray(scores, dtype=float)
        median_score = float(np.median(scores_array))

        # Distance from the median: smaller means a less confident prediction.
        distances = np.abs(scores_array - median_score)
        nearest = np.argsort(distances, kind="stable")[:top_k]

        selected = [patterns[int(i)] for i in nearest]
        logger.info(
            "Uncertainty sampling: selected %d patterns near median score %.4f",
            len(selected), median_score,
        )
        return selected

