"""Plotting utilities for experiment results."""

import matplotlib.pyplot as plt
import json
import os
from typing import List, Optional, Dict
import logging

logger = logging.getLogger(__name__)


def plot_learning_curve(
    baseline: List[float],
    ours: List[float],
    output_path: Optional[str] = None,
    title: str = "Learning Curve Comparison",
    xlabel: str = "Interaction Round",
    ylabel: str = "Accuracy",
    preference_weighted: Optional[List[float]] = None
) -> None:
    """Plot a learning curve comparing the baseline and contrastive methods.

    Rounds are numbered from 1 to ``len(baseline)`` on the x-axis and the
    y-axis is fixed to ``[0, 1]``. The figure is either written to
    ``output_path`` or shown interactively, and is always closed afterwards,
    even when plotting or saving raises.

    Args:
        baseline: List of accuracy scores for baseline method
        ours: List of accuracy scores for contrastive method; plotted against
            the same rounds as ``baseline``
        output_path: Path to save plot (optional). Its parent directory is
            created when missing. When falsy, the plot is shown instead.
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        preference_weighted: Optional list of accuracy scores for
            preference-weighted method (2025-style). Plotted only when its
            length equals ``len(baseline)``; otherwise skipped with a warning.
    """
    n_rounds = len(baseline)
    rounds = list(range(1, n_rounds + 1))

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(rounds, baseline, 'b-o', label='Baseline', linewidth=2, markersize=6)
        ax.plot(rounds, ours, 'r-s', label='Contrastive', linewidth=2, markersize=6)
        if preference_weighted is not None:
            if len(preference_weighted) == n_rounds:
                ax.plot(rounds, preference_weighted, 'g-^',
                        label='Preference-Weighted (2025)', linewidth=2, markersize=6)
            else:
                # Make the omission visible instead of silently dropping the curve.
                logger.warning(
                    "Skipping preference-weighted curve: %d values for %d rounds",
                    len(preference_weighted), n_rounds,
                )

        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1)
        # With fewer than two rounds the limits [1, n_rounds] would be
        # identical or inverted, so leave the x-axis to autoscale.
        if n_rounds > 1:
            ax.set_xlim(1, n_rounds)

        fig.tight_layout()

        if output_path:
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            logger.info("Plot saved to %s", output_path)
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


def plot_all_metrics_func(
    baseline_metrics: Dict[str, List[float]],
    contrastive_metrics: Dict[str, List[float]],
    output_path: Optional[str] = None,
    preference_weighted_metrics: Optional[Dict[str, List[float]]] = None
) -> None:
    """Plot all metrics (accuracy, precision, recall, F1) in a 2x2 grid of subplots.

    Each subplot shows the baseline and contrastive series for one metric,
    with the final value of both annotated next to the last point. Rounds on
    the x-axis are numbered from 1 to ``len(baseline_metrics['accuracy'])``.
    The figure is either written to ``output_path`` or shown interactively,
    and is always closed afterwards, even when plotting or saving raises.

    Args:
        baseline_metrics: Dictionary with 'accuracy', 'precision', 'recall', 'f1' lists
        contrastive_metrics: Dictionary with 'accuracy', 'precision', 'recall', 'f1' lists
        output_path: Path to save plot (optional). Its parent directory is
            created when missing. When falsy, the plot is shown instead.
        preference_weighted_metrics: Optional third method metrics (2025-style).
            A metric is drawn only when its key is present and its series has
            the same length as the baseline series; a length mismatch is
            skipped with a warning.

    Raises:
        KeyError: If any of the four metric keys is missing from
            ``baseline_metrics`` or ``contrastive_metrics``.
    """
    rounds = list(range(1, len(baseline_metrics['accuracy']) + 1))

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        fig.suptitle('Evaluation Metrics Comparison', fontsize=16, fontweight='bold')

        panels = [
            ('accuracy', 'Accuracy', axes[0, 0]),
            ('precision', 'Precision', axes[0, 1]),
            ('recall', 'Recall', axes[1, 0]),
            ('f1', 'F1 Score', axes[1, 1]),
        ]

        for metric_key, metric_name, ax in panels:
            baseline = baseline_metrics[metric_key]
            contrastive = contrastive_metrics[metric_key]

            ax.plot(rounds, baseline, 'b-o', label='Baseline', linewidth=2, markersize=5)
            ax.plot(rounds, contrastive, 'r-s', label='Contrastive', linewidth=2, markersize=5)
            if preference_weighted_metrics is not None and metric_key in preference_weighted_metrics:
                pw = preference_weighted_metrics[metric_key]
                if len(pw) == len(baseline):
                    ax.plot(rounds, pw, 'g-^', label='Preference-Weighted (2025)',
                            linewidth=2, markersize=5)
                else:
                    # Make the omission visible instead of silently dropping the curve.
                    logger.warning(
                        "Skipping preference-weighted '%s': %d values for %d rounds",
                        metric_key, len(pw), len(baseline),
                    )

            ax.set_xlabel('Interaction Round', fontsize=10)
            ax.set_ylabel(metric_name, fontsize=10)
            ax.set_title(metric_name, fontsize=12, fontweight='bold')
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)
            ax.set_ylim(0, 1)
            # With fewer than two rounds the limits [1, len(baseline)] would be
            # identical or inverted, so leave the x-axis to autoscale.
            if len(baseline) > 1:
                ax.set_xlim(1, len(baseline))

            # Add final value annotations
            if len(baseline) > 0:
                ax.annotate(f'{baseline[-1]:.3f}', (len(baseline), baseline[-1]),
                            textcoords="offset points", xytext=(5, 0), ha='left', fontsize=8)
                ax.annotate(f'{contrastive[-1]:.3f}', (len(contrastive), contrastive[-1]),
                            textcoords="offset points", xytext=(5, -12), ha='left', fontsize=8)

        fig.tight_layout()

        if output_path:
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            logger.info("All metrics plot saved to %s", output_path)
        else:
            plt.show()
    finally:
        # Release the figure even if a metric key was missing or saving failed.
        plt.close(fig)


def _load_metrics_json(path: str) -> Dict[str, List[float]]:
    """Read a metrics JSON file and normalise it to the dict format.

    A file whose top-level value is not a dict (the old format stored a bare
    list of accuracy scores) is wrapped as ``{'accuracy': <value>}``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data if isinstance(data, dict) else {'accuracy': data}


def plot_from_json(
    baseline_path: str,
    contrastive_path: str,
    output_path: Optional[str] = None,
    plot_all_metrics: bool = False,
    preference_weighted_path: Optional[str] = None
) -> None:
    """Load results from JSON files and plot.

    Each file may use the old format (a bare list of accuracy scores) or the
    new format (a dict of metric name to list); the format is detected per
    file, so the inputs do not have to agree with each other.

    Args:
        baseline_path: Path to baseline metrics JSON file
        contrastive_path: Path to contrastive metrics JSON file
        output_path: Path to save plot (optional)
        plot_all_metrics: If True, plot all metrics in subplots; if False, plot
            only accuracy. Falls back to accuracy only (with a warning) when
            either input lacks one of 'accuracy', 'precision', 'recall', 'f1'.
        preference_weighted_path: Optional path to preference-weighted metrics
            JSON (2025-style third method). A path that is not an existing
            file is skipped with a warning.

    Raises:
        OSError: If the baseline or contrastive file cannot be opened.
        json.JSONDecodeError: If a file does not contain valid JSON.
        KeyError: If the accuracy plot is drawn and the baseline or
            contrastive metrics have no 'accuracy' entry.
    """
    baseline_metrics = _load_metrics_json(baseline_path)
    contrastive_metrics = _load_metrics_json(contrastive_path)

    pw_metrics = None
    if preference_weighted_path:
        if os.path.isfile(preference_weighted_path):
            pw_metrics = _load_metrics_json(preference_weighted_path)
        else:
            logger.warning(
                "Preference-weighted metrics file not found, skipping: %s",
                preference_weighted_path,
            )

    if plot_all_metrics:
        required_keys = ('accuracy', 'precision', 'recall', 'f1')
        has_all_metrics = all(
            key in metrics
            for metrics in (baseline_metrics, contrastive_metrics)
            for key in required_keys
        )
        if has_all_metrics:
            plot_all_metrics_func(
                baseline_metrics,
                contrastive_metrics,
                output_path,
                preference_weighted_metrics=pw_metrics,
            )
            return
        logger.warning(
            "All-metrics plot requested but inputs do not contain all of %s; "
            "plotting accuracy only",
            ", ".join(required_keys),
        )

    # .get() keeps a preference-weighted file without 'accuracy' from aborting
    # the plot; the third curve is simply omitted in that case.
    pw_acc = pw_metrics.get('accuracy') if pw_metrics else None
    plot_learning_curve(
        baseline_metrics['accuracy'],
        contrastive_metrics['accuracy'],
        output_path=output_path,
        preference_weighted=pw_acc,
    )


if __name__ == "__main__":
    import argparse

    # Command-line entry point: collect the metric file locations and the
    # plotting mode, set up logging, then hand everything to plot_from_json.
    cli = argparse.ArgumentParser(description="Plot learning curves")

    # String-valued options as (flag, default, help text), registered in order.
    path_options = (
        ("--baseline", "results/metrics_baseline.json",
         "Path to baseline metrics JSON"),
        ("--contrastive", "results/metrics_contrastive.json",
         "Path to contrastive metrics JSON"),
        ("--preference-weighted", None,
         "Path to preference-weighted metrics JSON (optional third method)"),
        ("--output", "results/learning_curve.png",
         "Output path for plot"),
    )
    for flag, default, help_text in path_options:
        cli.add_argument(flag, type=str, default=default, help=help_text)

    # Boolean switches as (flag, help text).
    mode_switches = (
        ("--accuracy-only",
         "Only plot accuracy (single figure); default when not using --all-metrics"),
        ("--all-metrics",
         "Plot all metrics (accuracy, precision, recall, F1) in 2x2 subplots"),
    )
    for flag, help_text in mode_switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    opts = cli.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # The single accuracy figure is drawn unless --all-metrics is given;
    # --accuracy-only forces the accuracy figure even alongside --all-metrics.
    want_all_metrics = False if opts.accuracy_only else opts.all_metrics
    plot_from_json(
        baseline_path=opts.baseline,
        contrastive_path=opts.contrastive,
        output_path=opts.output,
        plot_all_metrics=want_all_metrics,
        preference_weighted_path=opts.preference_weighted,
    )

