"""Pydantic schemas for API request/response validation."""

from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel


class QueryRequest(BaseModel):
    """Request schema for the query endpoint.

    Attributes:
        query: The query string submitted by the client.
        iteration_rounds: Stage4 number of iteration rounds; ``None``
            (the default) when the client omits it.
    """
    # NOTE(review): no range constraint is declared on iteration_rounds, so
    # zero or negative values pass validation — confirm the handler guards them.
    query: str
    iteration_rounds: Optional[int] = None  # Stage4: Number of iteration rounds


class QueryResponse(BaseModel):
    """Response schema for the query endpoint.

    Attributes:
        params: Free-form parameter dict; presumably the mining parameters
            used for this query (cf. ``params`` on the iteration responses)
            — confirm with the producer.
        patterns: Returned patterns, each a free-form dict.
        rules: Returned rules, each a free-form dict.
        explanation: Explanation text accompanying the result.
        pattern_count: Pattern count reported by the server.
        rule_count: Rule count reported by the server.
        re_ranked: Whether the results were re-ranked; ``False`` by default.
        similarity_scores: Optional mapping of string key -> float score
            (presumably keyed by pattern key — verify); ``None`` by default.
        intent_used: Stage0 flag; ``True`` when intent understanding was used.
        intent_data: Stage0 intent data from the LLM, or ``None``.
        total_rounds: Stage4 total iteration rounds, or ``None``.
        iteration_history: Stage4 per-round history entries, or ``None``.
    """
    # NOTE(review): pattern_count / rule_count are not validated against
    # len(patterns) / len(rules); the producer must keep them consistent.
    params: Dict[str, Any]
    patterns: List[Dict[str, Any]]
    rules: List[Dict[str, Any]]
    explanation: str
    pattern_count: int
    rule_count: int
    re_ranked: bool = False
    similarity_scores: Optional[Dict[str, float]] = None
    intent_used: bool = False  # Stage0: Whether intent understanding was used
    intent_data: Optional[Dict[str, Any]] = None  # Stage0: Intent data from LLM
    total_rounds: Optional[int] = None  # Stage4: Total iteration rounds
    iteration_history: Optional[List[Dict[str, Any]]] = None  # Stage4: Iteration history


class FeedbackRequest(BaseModel):
    """Request schema for the feedback endpoint.

    Attributes:
        pattern_id: Integer identifier of the pattern being rated.
        pattern: Pattern content as a list of items, e.g. ``["Park", "Museum"]``.
        feedback: The user's judgement; must be exactly ``"positive"`` or
            ``"negative"``. Any other value (including different casing such
            as ``"Positive"``) is rejected during validation instead of
            reaching the endpoint handler.
    """
    pattern_id: int
    pattern: List[str]  # Pattern content, e.g. ["Park", "Museum"]
    # A Literal puts the documented closed set of values into the schema
    # itself: invalid values fail validation, and the allowed values appear
    # as an enum in the JSON schema Pydantic generates for this model.
    feedback: Literal["positive", "negative"]


class FeedbackResponse(BaseModel):
    """Response schema for the feedback endpoint.

    Attributes:
        status: Status string for the feedback operation.
        message: Human-readable message accompanying the status.
        positive_count: Count of positive feedback (presumably a running
            total after this submission — confirm with the endpoint).
        negative_count: Count of negative feedback (same caveat as above).
    """
    status: str
    message: str
    positive_count: int
    negative_count: int


class TrainRequest(BaseModel):
    """Request schema for the train endpoint.

    Every field is optional; omitted fields take the defaults shown.

    Attributes:
        epochs: Number of training epochs (default ``10``).
        learning_rate: Learning rate (default ``0.001``).
        batch_size: Batch size (default ``32``).
        margin: Margin hyperparameter (default ``0.3``); presumably the
            margin of a margin-based ranking loss — confirm with the trainer.
    """
    # NOTE(review): because these are Optional[...], a client that sends an
    # explicit null passes validation and overrides the default with None;
    # confirm the training code handles None or tighten the types.
    epochs: Optional[int] = 10
    learning_rate: Optional[float] = 0.001
    batch_size: Optional[int] = 32
    margin: Optional[float] = 0.3


class TrainResponse(BaseModel):
    """Response schema for the train endpoint.

    Attributes:
        status: Status string for the training run.
        message: Human-readable message accompanying the status.
        loss_history: Recorded loss values (presumably one per epoch —
            verify against the trainer).
        epochs: Number of epochs reported for the run.
    """
    status: str
    message: str
    loss_history: List[float]
    epochs: int


class TrainingHistoryResponse(BaseModel):
    """Response schema for the training history endpoint.

    Attributes:
        trainings: List of training records, each a free-form dict; the
            record schema is not declared in this model.
    """
    trainings: List[Dict[str, Any]]


# Stage4: Iterative interaction schemas
class IterationStartRequest(BaseModel):
    """Request schema for starting an iterative (Stage4) session.

    Attributes:
        query: The query string to iterate on.
        iteration_rounds: Requested number of iteration rounds (required).
    """
    query: str
    iteration_rounds: int


class IterationStartResponse(BaseModel):
    """Response schema for starting an iterative (Stage4) session.

    Attributes:
        session_id: Identifier of the created session; later requests
            (next step / finalize) send it back.
        round: Current round number.
        total_rounds: Total number of rounds for the session.
        patterns: Patterns for this round, each a free-form dict.
        rules: Rules for this round, each a free-form dict.
        similarity_scores: Optional mapping of string key -> float score.
        is_final: Whether this round is the last one; ``False`` by default.
        intent_used: Stage0 flag; ``True`` when intent understanding was used.
        intent_data: Stage0 intent data from the LLM, or ``None``.
        params: Mining parameters, or ``None``.
    """
    session_id: str
    round: int
    total_rounds: int
    patterns: List[Dict[str, Any]]
    rules: List[Dict[str, Any]]
    similarity_scores: Optional[Dict[str, float]] = None
    is_final: bool = False
    intent_used: bool = False  # Stage0: Whether intent understanding was used
    intent_data: Optional[Dict[str, Any]] = None  # Stage0: Intent data from LLM
    params: Optional[Dict[str, Any]] = None  # Mining parameters


class IterationNextRequest(BaseModel):
    """Request schema for advancing a Stage4 session by one step.

    Attributes:
        session_id: Identifier of the session to advance.
        feedback: Feedback for the previous round, shaped like
            ``{"positive": [...], "negative": [...]}``.
    """
    session_id: str
    # NOTE(review): only the value type (List[str]) is validated; the
    # "positive"/"negative" key set is not enforced by this schema.
    feedback: Dict[str, List[str]]  # {"positive": [...], "negative": [...]}


class IterationNextResponse(BaseModel):
    """Response schema for one step of a Stage4 session.

    Attributes:
        round: Current round number.
        total_rounds: Total number of rounds for the session.
        patterns: Patterns for this round, each a free-form dict.
        rules: Rules for this round, each a free-form dict.
        similarity_scores: Optional mapping of string key -> float score.
        training_result: Optional free-form dict; presumably the result of
            training on the submitted feedback — confirm with the endpoint.
        is_final: Whether this round is the last one; ``False`` by default.
        intent_used: Stage0 flag; ``True`` when intent understanding was used.
        intent_data: Stage0 intent data from the LLM, or ``None``.
    """
    round: int
    total_rounds: int
    patterns: List[Dict[str, Any]]
    rules: List[Dict[str, Any]]
    similarity_scores: Optional[Dict[str, float]] = None
    training_result: Optional[Dict[str, Any]] = None
    is_final: bool = False
    intent_used: bool = False  # Stage0: Whether intent understanding was used
    intent_data: Optional[Dict[str, Any]] = None  # Stage0: Intent data from LLM


class IterationFinalizeRequest(BaseModel):
    """Request schema for finalizing a Stage4 session.

    Attributes:
        session_id: Identifier of the session to finalize.
    """
    session_id: str


class IterationFinalizeResponse(BaseModel):
    """Response schema for finalizing a Stage4 session.

    Attributes:
        patterns: Final patterns, each a free-form dict.
        rules: Final rules, each a free-form dict.
        similarity_scores: Optional mapping of string key -> float score.
        explanation: Explanation text for the final result.
        iteration_history: Per-round history entries. Unlike
            ``QueryResponse.iteration_history``, this field is required.
        intent_used: Stage0 flag; ``True`` when intent understanding was used.
        intent_data: Stage0 intent data from the LLM, or ``None``.
        params: Mining parameters, or ``None``.
    """
    patterns: List[Dict[str, Any]]
    rules: List[Dict[str, Any]]
    similarity_scores: Optional[Dict[str, float]] = None
    explanation: str
    iteration_history: List[Dict[str, Any]]
    intent_used: bool = False  # Stage0: Whether intent understanding was used
    intent_data: Optional[Dict[str, Any]] = None  # Stage0: Intent data from LLM
    params: Optional[Dict[str, Any]] = None  # Mining parameters


class IterationStepRequest(BaseModel):
    """Request schema for the Stage4 iteration step endpoint.

    Attributes:
        query: The query string.
        current_round: The round being requested.
        total_rounds: Total number of rounds.
        mining_params: Optional mining parameters; ``None`` when omitted.
    """
    # NOTE(review): no cross-field check ensures current_round <= total_rounds.
    query: str
    current_round: int
    total_rounds: int
    mining_params: Optional[Dict[str, Any]] = None


class IterationStepResponse(BaseModel):
    """Response schema for the Stage4 iteration step endpoint.

    Attributes:
        round: Round number this response belongs to.
        patterns: Patterns for this round, each a free-form dict.
        rules: Rules for this round, each a free-form dict.
        total_patterns: Total pattern count reported by the server
            (may differ from ``len(patterns)`` — not validated here).
        total_rules: Total rule count reported by the server (same caveat).
        similarity_scores: Optional mapping of string key -> float score.
        user_vector_norm: Norm of the user vector, as a float.
        is_final: Whether this round is the last one. Required here, whereas
            the other iteration responses default it to ``False``.
        params: Parameters used for this step (required).
    """
    round: int
    patterns: List[Dict[str, Any]]
    rules: List[Dict[str, Any]]
    total_patterns: int
    total_rules: int
    similarity_scores: Optional[Dict[str, float]] = None
    user_vector_norm: float
    is_final: bool
    params: Dict[str, Any]


class IterationFeedbackRequest(BaseModel):
    """Request schema for the Stage4 iteration feedback endpoint.

    Attributes:
        positive_patterns: Pattern keys marked positive.
        negative_patterns: Pattern keys marked negative.
        training_epochs: Number of epochs for quick training (default ``5``).
    """
    # NOTE(review): "comma-separated" below presumably means each list entry
    # is one pattern key whose items are joined by commas (e.g. "Park,Museum")
    # — confirm against the key format the server builds.
    positive_patterns: List[str]  # List of pattern keys (comma-separated)
    negative_patterns: List[str]   # List of pattern keys (comma-separated)
    # Optional[int]: an explicit null passes validation and replaces the default.
    training_epochs: Optional[int] = 5  # Number of epochs for quick training


class IterationFeedbackResponse(BaseModel):
    """Response schema for the Stage4 iteration feedback endpoint.

    Attributes:
        status: Status string for the feedback/training operation.
        user_vector_norm: Norm of the user vector, as a float (presumably
            measured after training — confirm with the endpoint).
        training: Free-form dict describing the training run; its keys are
            not declared in this model.
    """
    status: str
    user_vector_norm: float
    training: Dict[str, Any]

