"""Evaluation metrics for preference learning."""

import logging
from typing import Any, Dict, List, Tuple

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

logger = logging.getLogger(__name__)


def _aligned_labels(
    y_true: Dict[int, int], y_pred: Dict[int, int]
) -> Tuple[List[int], List[int]]:
    """Pair up true and predicted labels for the pattern ids present in both.

    Ids that appear in only one of the two dicts are ignored. The shared ids
    are sorted so both returned lists are in the same deterministic order.

    Args:
        y_true: Ground truth labels (dict mapping pattern_id to label)
        y_pred: Predicted labels (dict mapping pattern_id to label)

    Returns:
        ``(true_labels, pred_labels)``: equal-length lists aligned by id.
    """
    # Dict key views support set operations directly; no need to build sets.
    pattern_ids = sorted(y_true.keys() & y_pred.keys())
    true_labels = [y_true[pid] for pid in pattern_ids]
    pred_labels = [y_pred[pid] for pid in pattern_ids]
    return true_labels, pred_labels


def accuracy(y_true: Dict[int, int], y_pred: Dict[int, int]) -> float:
    """Compute accuracy over the pattern ids present in both dicts.

    If the two dicts share no ids, empty label lists are passed to sklearn
    unchanged and its result for empty input is returned.

    Args:
        y_true: Ground truth labels (dict mapping pattern_id to label)
        y_pred: Predicted labels (dict mapping pattern_id to label)

    Returns:
        Accuracy score (fraction of shared ids whose labels match)
    """
    true_labels, pred_labels = _aligned_labels(y_true, y_pred)
    return float(accuracy_score(true_labels, pred_labels))


def precision(y_true: Dict[int, int], y_pred: Dict[int, int]) -> float:
    """Compute precision over the pattern ids present in both dicts.

    Uses sklearn's default averaging, so labels are presumably binary
    (positive label 1) -- confirm against callers. When precision is
    undefined (no positive predictions), 0.0 is returned.

    Args:
        y_true: Ground truth labels (dict mapping pattern_id to label)
        y_pred: Predicted labels (dict mapping pattern_id to label)

    Returns:
        Precision score
    """
    true_labels, pred_labels = _aligned_labels(y_true, y_pred)
    return float(precision_score(true_labels, pred_labels, zero_division=0.0))


def recall(y_true: Dict[int, int], y_pred: Dict[int, int]) -> float:
    """Compute recall over the pattern ids present in both dicts.

    Uses sklearn's default averaging, so labels are presumably binary
    (positive label 1) -- confirm against callers. When recall is
    undefined (no positive ground-truth labels), 0.0 is returned.

    Args:
        y_true: Ground truth labels (dict mapping pattern_id to label)
        y_pred: Predicted labels (dict mapping pattern_id to label)

    Returns:
        Recall score
    """
    true_labels, pred_labels = _aligned_labels(y_true, y_pred)
    return float(recall_score(true_labels, pred_labels, zero_division=0.0))


def f1(y_true: Dict[int, int], y_pred: Dict[int, int]) -> float:
    """Compute the F1 score over the pattern ids present in both dicts.

    Uses sklearn's default averaging, so labels are presumably binary
    (positive label 1) -- confirm against callers. When F1 is undefined,
    0.0 is returned.

    Args:
        y_true: Ground truth labels (dict mapping pattern_id to label)
        y_pred: Predicted labels (dict mapping pattern_id to label)

    Returns:
        F1 score
    """
    true_labels, pred_labels = _aligned_labels(y_true, y_pred)
    return float(f1_score(true_labels, pred_labels, zero_division=0.0))


def compute_threshold(scores: List[float]) -> float:
    """Compute a dynamic decision threshold as the median of ``scores``.

    IMPORTANT: Do NOT use fixed threshold. The threshold must follow the
    observed score distribution rather than a hard-coded cutoff.

    Args:
        scores: Prediction scores (a list, or any sized sequence such as a
            1-D numpy array)

    Returns:
        Threshold value (median of scores), or 0.0 when ``scores`` is empty
    """
    # len() instead of truthiness so numpy arrays are accepted too
    # (``not array`` raises for arrays with more than one element).
    if len(scores) == 0:
        return 0.0

    # The documented contract is the median. A previous hard-coded 0.7
    # contradicted both this docstring and the log message below.
    threshold = float(np.median(scores))
    logger.debug(
        "Computed threshold: %.4f (median of %d scores)", threshold, len(scores)
    )
    return threshold

