"""Experiment runner for comparing baseline and contrastive methods."""

import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from controller.manager import PipelineManager
from experiment.evaluator import InteractionEvaluator
from experiment.simulator import UserSimulator

logger = logging.getLogger(__name__)


# Metric keys reported in the end-of-run summary, paired with display labels.
_SUMMARY_METRICS = (
    ("accuracy", "Accuracy"),
    ("precision", "Precision"),
    ("recall", "Recall"),
    ("f1", "F1"),
)


def _log_banner(title: str) -> None:
    """Log ``title`` framed by '=' separator lines."""
    logger.info("\n" + "=" * 60)
    logger.info(title)
    logger.info("=" * 60)


def _init_manager(config_path: str, label: str):
    """Build a PipelineManager for ``config_path``; log and re-raise any failure."""
    logger.info(f"Initializing {label} manager...")
    try:
        manager = PipelineManager(config_path)
    except Exception as e:
        logger.error(f"Failed to initialize {label} manager: {e}", exc_info=True)
        raise
    logger.info(f"{label.capitalize()} manager initialized successfully")
    return manager


def _evaluate(manager, simulator, top_k: int, rounds: int, sampling_strategy: str):
    """Run one InteractionEvaluator pass over ``manager`` and return its results."""
    evaluator = InteractionEvaluator(
        manager,
        simulator,
        top_k=top_k,
        rounds=rounds,
        sampling_strategy=sampling_strategy,
    )
    return evaluator.run()


def _write_json(path: str, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def _save_results(output_dir: str, name: str, results: Dict[str, Any]) -> tuple:
    """Save full metrics and accuracy-only JSON for one method.

    Returns:
        ``(metrics_path, accuracy_path)``.
    """
    metrics_path = os.path.join(output_dir, f"metrics_{name}.json")
    accuracy_path = os.path.join(output_dir, f"accuracy_{name}.json")
    _write_json(metrics_path, results)
    _write_json(accuracy_path, results["accuracy"])
    return metrics_path, accuracy_path


def _final_value(results: Dict[str, Any], key: str) -> Optional[float]:
    """Return the last value of metric series ``key``, or None if the series is empty."""
    series = results[key]
    # An empty series (presumably when no rounds ran) must not crash the summary.
    return series[-1] if series else None


def _fmt(value: Optional[float]) -> str:
    """Format a metric value to 4 decimals, or 'n/a' when unavailable."""
    return "n/a" if value is None else f"{value:.4f}"


def _log_final_metrics(title: str, results: Dict[str, Any]) -> None:
    """Log the final value of every summary metric for one method."""
    logger.info(f"{title} - Final Metrics:")
    for key, label in _SUMMARY_METRICS:
        logger.info(f"  {label + ':':<10} {_fmt(_final_value(results, key))}")


def run_experiment(
    baseline_config: str = "config/config_baseline.yaml",
    contrastive_config: str = "config/config_contrastive.yaml",
    preference_weighted_config: Optional[str] = None,
    liked_features: Optional[set] = None,
    disliked_features: Optional[set] = None,
    top_k: int = 5,
    rounds: int = 10,
    output_dir: str = "results",
    sampling_strategy: str = "mixed"
) -> Dict[str, Any]:
    """Run experiment comparing baseline, contrastive, and optionally preference-weighted methods.

    For each method, ``metrics_<name>.json`` (all metrics) and
    ``accuracy_<name>.json`` (accuracy series only) are written to
    ``output_dir``.

    Args:
        baseline_config: Path to baseline configuration file
        contrastive_config: Path to contrastive configuration file
        preference_weighted_config: Optional path to preference-weighted config (2025-style);
            if None, only run baseline vs contrastive
        liked_features: Set of features user likes (default: {"Restaurant"})
        disliked_features: Set of features user dislikes (default: None)
        top_k: Number of top patterns to recommend each round
        rounds: Number of interaction rounds
        output_dir: Directory to save results (created if missing)
        sampling_strategy: Pattern sampling strategy forwarded to InteractionEvaluator
            (the CLI offers "top", "mixed", "stratified", "uncertainty")

    Returns:
        Dictionary with 'baseline' and 'contrastive' results, plus 'preference_weighted'
        when that run produced results.

    Raises:
        Any exception raised while constructing the baseline or contrastive
        PipelineManager (logged, then re-raised). Failures in the optional
        preference-weighted run are logged and do not abort the experiment.
    """
    if liked_features is None:
        liked_features = {"Restaurant"}

    run_third = preference_weighted_config is not None
    logger.info("=" * 60)
    logger.info("Starting Experiment: Baseline vs Contrastive" + (" vs Preference-Weighted" if run_third else ""))
    logger.info("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    # One simulator instance is shared by every evaluated method.
    simulator = UserSimulator(
        liked_features=liked_features,
        disliked_features=disliked_features
    )

    baseline_manager = _init_manager(baseline_config, "baseline")
    contrastive_manager = _init_manager(contrastive_config, "contrastive")

    _log_banner("Running Baseline Evaluation")
    baseline_results = _evaluate(baseline_manager, simulator, top_k, rounds, sampling_strategy)

    _log_banner("Running Contrastive Evaluation")
    contrastive_results = _evaluate(contrastive_manager, simulator, top_k, rounds, sampling_strategy)

    # Persist the two mandatory runs *before* the optional (possibly long) third
    # run, so an interruption there cannot discard already-computed results.
    baseline_path, baseline_acc_path = _save_results(output_dir, "baseline", baseline_results)
    contrastive_path, contrastive_acc_path = _save_results(output_dir, "contrastive", contrastive_results)

    logger.info("\nResults saved:")
    logger.info(f"  Baseline (all metrics): {baseline_path}")
    logger.info(f"  Contrastive (all metrics): {contrastive_path}")
    logger.info(f"  Baseline (accuracy only): {baseline_acc_path}")
    logger.info(f"  Contrastive (accuracy only): {contrastive_acc_path}")

    preference_weighted_results = None
    if run_third:
        _log_banner("Running Preference-Weighted Evaluation (2025-style adaptive fusion)")
        try:
            pw_manager = PipelineManager(preference_weighted_config)
            preference_weighted_results = _evaluate(pw_manager, simulator, top_k, rounds, sampling_strategy)
            pw_path, _ = _save_results(output_dir, "preference_weighted", preference_weighted_results)
            logger.info(f"  Preference-Weighted (all metrics): {pw_path}")
        except Exception as e:
            # Deliberate best-effort: the third method is optional, so its failure
            # is logged and the baseline/contrastive comparison still completes.
            logger.error(f"Preference-Weighted run failed: {e}", exc_info=True)

    _log_banner("Experiment Summary")
    _log_final_metrics("Baseline", baseline_results)
    _log_final_metrics("\nContrastive", contrastive_results)
    if preference_weighted_results:
        _log_final_metrics("\nPreference-Weighted", preference_weighted_results)

    logger.info("\nImprovement (Contrastive vs Baseline):")
    for key, label in _SUMMARY_METRICS:
        new = _final_value(contrastive_results, key)
        old = _final_value(baseline_results, key)
        delta = "n/a" if new is None or old is None else f"{new - old:.4f}"
        logger.info(f"  {label + ':':<10} {delta}")

    out = {'baseline': baseline_results, 'contrastive': contrastive_results}
    if preference_weighted_results is not None:
        out['preference_weighted'] = preference_weighted_results
    return out


if __name__ == "__main__":
    import argparse

    # Command-line interface mirroring run_experiment's keyword arguments.
    cli_parser = argparse.ArgumentParser(
        description="Run preference learning experiment",
    )

    # Configuration files for each compared method.
    cli_parser.add_argument(
        "--baseline-config",
        type=str,
        default="config/config_baseline.yaml",
        help="Path to baseline configuration file",
    )
    cli_parser.add_argument(
        "--contrastive-config",
        type=str,
        default="config/config_contrastive.yaml",
        help="Path to contrastive configuration file",
    )
    cli_parser.add_argument(
        "--preference-weighted-config",
        type=str,
        default=None,
        help="Path to preference-weighted config (2025-style); if set, run 3-way comparison",
    )

    # Simulated user preferences (one or more feature names each).
    cli_parser.add_argument(
        "--liked-features",
        type=str,
        nargs="+",
        default=["Restaurant"],
        help="Features user likes",
    )
    cli_parser.add_argument(
        "--disliked-features",
        type=str,
        nargs="+",
        default=None,
        help="Features user dislikes",
    )

    # Interaction-loop settings.
    cli_parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top patterns to recommend",
    )
    cli_parser.add_argument(
        "--rounds",
        type=int,
        default=10,
        help="Number of interaction rounds",
    )
    cli_parser.add_argument(
        "--output-dir",
        type=str,
        default="results",
        help="Output directory for results",
    )
    cli_parser.add_argument(
        "--sampling-strategy",
        type=str,
        default="mixed",
        choices=["top", "mixed", "stratified", "uncertainty"],
        help="Pattern sampling strategy: top (original), mixed (recommended), stratified, uncertainty",
    )

    cli_args = cli_parser.parse_args()

    # INFO-level root logging with timestamped records.
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    # argparse yields lists; the experiment expects sets (or None when unset).
    liked = set(cli_args.liked_features)
    disliked = set(cli_args.disliked_features) if cli_args.disliked_features else None

    results = run_experiment(
        baseline_config=cli_args.baseline_config,
        contrastive_config=cli_args.contrastive_config,
        preference_weighted_config=cli_args.preference_weighted_config,
        liked_features=liked,
        disliked_features=disliked,
        top_k=cli_args.top_k,
        rounds=cli_args.rounds,
        output_dir=cli_args.output_dir,
        sampling_strategy=cli_args.sampling_strategy,
    )

