"""API routes for FastAPI application."""

import logging
from fastapi import APIRouter, HTTPException
from typing import Dict, Any

from web.schemas import (
    QueryRequest, QueryResponse,
    FeedbackRequest, FeedbackResponse,
    TrainRequest, TrainResponse,
    TrainingHistoryResponse,
    IterationStartRequest, IterationStartResponse,
    IterationNextRequest, IterationNextResponse,
    IterationFinalizeRequest, IterationFinalizeResponse
)
from web.services import WebService

logger = logging.getLogger(__name__)


def _internal_error(endpoint: str, exc: Exception) -> HTTPException:
    """Log an unexpected endpoint failure and build the matching 500 error.

    Args:
        endpoint: Human-readable endpoint label used in the log message
        exc: The exception that was caught by the endpoint

    Returns:
        HTTPException with status 500 and ``str(exc)`` as its detail
    """
    # Passing the exception instance as exc_info attaches its traceback to the
    # log record without relying on the caller's active exception context.
    logger.error("Error in %s endpoint: %s", endpoint, exc, exc_info=exc)
    return HTTPException(status_code=500, detail=str(exc))


def create_router(service: WebService) -> APIRouter:
    """Create API router with all endpoints.

    Every endpoint delegates to ``service`` and unpacks the returned mapping
    into its response schema. An ``HTTPException`` raised while handling a
    request is propagated unchanged; any other exception is logged with its
    traceback and reported as a 500 whose detail is the exception text.

    Args:
        service: WebService instance

    Returns:
        Configured APIRouter (prefix ``/api``, tag ``api``)
    """
    router = APIRouter(prefix="/api", tags=["api"])

    @router.post("/query", response_model=QueryResponse)
    async def query(request: QueryRequest) -> QueryResponse:
        """Process a natural language query and return patterns.

        Args:
            request: Query request with natural language query string and optional iteration rounds

        Returns:
            Query response with patterns, rules, and explanation

        Raises:
            HTTPException: 500 if the service call or response construction fails
        """
        try:
            result = service.run_query(request.query, iteration_rounds=request.iteration_rounds)
            return QueryResponse(**result)
        except HTTPException:
            raise
        except Exception as e:
            raise _internal_error("query", e) from e

    @router.post("/feedback", response_model=FeedbackResponse)
    async def feedback(request: FeedbackRequest) -> FeedbackResponse:
        """Save user feedback for a pattern.

        Args:
            request: Feedback request with pattern and feedback type

        Returns:
            Feedback response with status and counts

        Raises:
            HTTPException: 400 if the feedback type is not 'positive' or
                'negative'; 500 if the service call or response construction fails
        """
        # Validate before entering the try block so the 400 reaches the client
        # instead of being caught below and re-reported as a 500.
        if request.feedback not in ("positive", "negative"):
            raise HTTPException(
                status_code=400,
                detail="Feedback must be 'positive' or 'negative'"
            )

        try:
            result = service.add_feedback(
                pattern_id=request.pattern_id,
                pattern=request.pattern,
                feedback=request.feedback
            )
            return FeedbackResponse(**result)
        except HTTPException:
            raise
        except Exception as e:
            raise _internal_error("feedback", e) from e

    @router.post("/train", response_model=TrainResponse)
    async def train(request: TrainRequest) -> TrainResponse:
        """Trigger preference model training.

        Hyperparameters that are ``None`` on the request fall back to
        epochs=10, learning_rate=0.001, batch_size=32 and margin=0.3.

        Args:
            request: Training request with hyperparameters

        Returns:
            Training response with loss history

        Raises:
            HTTPException: 500 if the service call or response construction fails
        """
        # Use explicit None checks: `value or default` would also replace a
        # deliberately supplied 0 / 0.0 (e.g. margin=0.0) with the default.
        epochs = request.epochs if request.epochs is not None else 10
        learning_rate = request.learning_rate if request.learning_rate is not None else 0.001
        batch_size = request.batch_size if request.batch_size is not None else 32
        margin = request.margin if request.margin is not None else 0.3

        try:
            # NOTE(review): the service is called synchronously inside an async
            # handler; if training is long-running it will block the event
            # loop. Confirm whether it should be offloaded to a worker thread.
            result = service.run_training(
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                margin=margin
            )
            return TrainResponse(**result)
        except HTTPException:
            raise
        except Exception as e:
            raise _internal_error("train", e) from e

    @router.get("/train/history", response_model=TrainingHistoryResponse)
    async def get_training_history() -> TrainingHistoryResponse:
        """Get training history.

        Returns:
            Training history response with all training records

        Raises:
            HTTPException: 500 if the service call or response construction fails
        """
        try:
            history = service.get_training_history()
            return TrainingHistoryResponse(**history)
        except HTTPException:
            raise
        except Exception as e:
            raise _internal_error("training history", e) from e

    # Stage4: Interactive iteration endpoints
    @router.post("/iteration/start", response_model=IterationStartResponse)
    async def start_iteration(request: IterationStartRequest) -> IterationStartResponse:
        """Start an interactive iteration session (Stage4).

        Args:
            request: Iteration start request with query and rounds

        Returns:
            Iteration start response with session_id and first round results

        Raises:
            HTTPException: 500 if the service call or response construction fails
        """
        try:
            result = service.start_iteration(request.query, request.iteration_rounds)
            return IterationStartResponse(**result)
        except HTTPException:
            raise
        except Exception as e:
            raise _internal_error("iteration start", e) from e

    @router.post("/iteration/next", response_model=IterationNextResponse)
    async def next_iteration(request: IterationNextRequest) -> IterationNextResponse:
        """Process feedback, train model, and return next iteration round (Stage4).

        Args:
            request: Iteration next request with session_id and feedback

        Returns:
            Iteration next response with next round results

        Raises:
            HTTPException: 500 if the service call or response construction fails
        """
        try:
            result = service.next_iteration(request.session_id, request.feedback)
            return IterationNextResponse(**result)
        except HTTPException:
            raise
        except Exception as e:
            raise _internal_error("iteration next", e) from e

    @router.post("/iteration/finalize", response_model=IterationFinalizeResponse)
    async def finalize_iteration(request: IterationFinalizeRequest) -> IterationFinalizeResponse:
        """Get final results after all iterations (Stage4).

        Args:
            request: Iteration finalize request with session_id

        Returns:
            Iteration finalize response with final results

        Raises:
            HTTPException: 500 if the service call or response construction fails
        """
        try:
            result = service.finalize_iteration(request.session_id)
            return IterationFinalizeResponse(**result)
        except HTTPException:
            raise
        except Exception as e:
            raise _internal_error("iteration finalize", e) from e

    return router

