"""Intent encoder for understanding user goals using LLM."""

import json
import logging
from typing import Dict, Any, Optional

from llm.intent_prompt import IntentPrompt
from llm.client import LLMClient

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class IntentEncoder:
    """Encoder for understanding user intent using LLM.

    The encoder builds an intent prompt for a user query, sends it to the LLM
    client, extracts the first JSON object from the reply and, when a list of
    dataset POI types was supplied, filters the returned patterns so they only
    reference those types.
    """

    def __init__(self, llm_client: "LLMClient", available_poi_types: Optional[list] = None):
        """Initialize intent encoder.

        Args:
            llm_client: LLM client instance; only its ``generate(prompt,
                system=...)`` method is used by this class.
            available_poi_types: List of actual POI types in the dataset.
                ``None`` or an empty list disables POI-type constraints.
        """
        self.llm = llm_client
        self.available_poi_types = available_poi_types or []
        self.prompt_template = IntentPrompt.INTENT_PROMPT

    def parse(self, query: str) -> Optional[Dict[str, Any]]:
        """Parse user query to extract intent.

        This method implements Stage0 intent understanding:
        1. User ambiguous query (e.g., "I want to open a barbecue restaurant")
        2. LLM analyzes query and extracts structured intent
        3. Returns structured intent with business type, pattern preferences, etc.

        Args:
            query: User's natural language query

        Returns:
            Dictionary containing parsed intent, or None if parsing fails.
            Every failure (LLM error, empty reply, unparsable reply) is logged
            and reported as None; no exception propagates to the caller.
        """
        try:
            logger.info("IntentEncoder: Starting intent understanding for query: %s...", query[:100])

            # Step 1: Build prompt with actual POI types if available
            if self.available_poi_types:
                logger.info(
                    "IntentEncoder: Building prompt with %d available POI types",
                    len(self.available_poi_types),
                )
                prompt = IntentPrompt.build_intent_prompt(query, self.available_poi_types)
            else:
                logger.info("IntentEncoder: Building prompt without POI type constraints")
                prompt = self.prompt_template.format(query=query)

            # Step 2: Call LLM for intent understanding
            logger.info("IntentEncoder: Calling LLM for intent understanding...")
            system_msg = "You are an expert in spatial business analysis."
            raw_response = self.llm.generate(prompt, system=system_msg)

            # An empty or missing reply is an expected failure mode, not a
            # programming error: report it as a warning instead of letting
            # len(None) raise and be logged with a traceback.
            if not raw_response:
                logger.warning("IntentEncoder: LLM returned an empty response")
                return None
            logger.debug("IntentEncoder: LLM raw response length: %d", len(raw_response))

            # Step 3: Extract JSON from response
            logger.info("IntentEncoder: Extracting structured intent from LLM response...")
            intent_json = self._extract_json(raw_response)

            if intent_json is None:
                logger.warning(
                    "IntentEncoder: Failed to extract valid JSON from LLM response: %s",
                    raw_response[:200],
                )
                return None

            logger.info("IntentEncoder: Extracted intent JSON with keys: %s", list(intent_json.keys()))

            # Step 4: Validate and filter pattern_preference to only include valid POI types
            if self.available_poi_types:
                logger.info("IntentEncoder: Validating patterns against dataset POI types...")
                original_count = len(intent_json.get("pattern_preference", []))
                intent_json = self._validate_patterns(intent_json)
                validated_count = len(intent_json.get("pattern_preference", []))
                logger.info(
                    "IntentEncoder: Pattern validation: %d → %d valid patterns",
                    original_count,
                    validated_count,
                )

            # Log extracted intent summary
            business = intent_json.get("business", "N/A")
            patterns = intent_json.get("pattern_preference", [])
            logger.info(
                "IntentEncoder: Successfully parsed intent - Business: %s, Patterns: %d",
                business,
                len(patterns),
            )
            logger.info("IntentEncoder: Preferred patterns: %s", patterns)

            return intent_json

        except Exception as e:
            # Boundary handler: the documented contract is "None on failure",
            # so anything unexpected is logged with its traceback and swallowed.
            logger.exception("IntentEncoder: Error parsing intent: %s", e)
            return None

    def _validate_patterns(self, intent_json: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and filter pattern_preference to only include valid POI types.

        The dictionary is modified in place and also returned. A dictionary
        without a ``pattern_preference`` key is returned untouched.

        Args:
            intent_json: Parsed intent JSON

        Returns:
            Validated intent JSON with filtered patterns. Each kept pattern
            contains only POI types present in ``self.available_poi_types``
            and has at least two of them.
        """
        if 'pattern_preference' not in intent_json:
            return intent_json

        raw_patterns = intent_json['pattern_preference']
        if not isinstance(raw_patterns, list):
            # A present-but-malformed value (e.g. null) means "no patterns",
            # rather than a TypeError when iterating it.
            raw_patterns = []

        # Build the lookup set once; membership tests below are O(1).
        poi_set = set(self.available_poi_types)
        valid_patterns = []

        for pattern in raw_patterns:
            if not isinstance(pattern, list):
                continue

            # Filter pattern to only include valid POI types
            valid_pois = [poi for poi in pattern if isinstance(poi, str) and poi in poi_set]

            # Only keep patterns that have at least 2 valid POI types
            if len(valid_pois) >= 2:
                valid_patterns.append(valid_pois)
            elif len(valid_pois) == 1:
                # A single POI type is not a usable pattern; it is skipped
                # rather than paired with a guessed complementary type.
                logger.debug("Skipping single POI pattern: %s", valid_pois)

        intent_json['pattern_preference'] = valid_patterns

        if not valid_patterns:
            logger.warning("No valid patterns found after filtering. Original patterns may have used POI types not in dataset.")

        return intent_json

    def _extract_json(self, text: str) -> Optional[Dict[str, Any]]:
        """Extract JSON object from LLM response.

        Decoding starts at the first ``{`` in the text and stops at the end of
        that JSON object, so surrounding prose or markdown fences are ignored.
        A real JSON decoder is used instead of brace counting, so braces that
        appear inside string values do not confuse the extraction.

        If the object has a ``pattern_preference`` key, its value is
        normalized to a list of lists: a non-list value becomes ``[]``, a
        string element becomes a one-element list, and any other non-list
        element becomes ``[]``.

        Args:
            text: Raw LLM response text

        Returns:
            Parsed JSON dictionary, or None if extraction fails
        """
        if not isinstance(text, str):
            logger.warning("Cannot extract JSON from non-text response of type %s", type(text).__name__)
            return None

        start_idx = text.find('{')
        if start_idx == -1:
            return None

        try:
            # raw_decode parses exactly one JSON value starting at start_idx
            # and tolerates trailing text after it.
            intent_data, _end_idx = json.JSONDecoder().raw_decode(text, start_idx)
        except json.JSONDecodeError as e:
            logger.warning("JSON decode error: %s", e)
            return None
        except RecursionError as e:
            # Pathologically nested input from the LLM.
            logger.error("Error extracting JSON: %s", e)
            return None

        # Decoding began on '{', so the decoded value is always a dict.

        # Ensure pattern_preference is a list of lists
        if 'pattern_preference' in intent_data:
            raw_patterns = intent_data['pattern_preference']
            if isinstance(raw_patterns, list):
                intent_data['pattern_preference'] = [
                    p if isinstance(p, list) else ([p] if isinstance(p, str) else [])
                    for p in raw_patterns
                ]
            else:
                intent_data['pattern_preference'] = []

        return intent_data

