"""Preference parser for converting LLM output into structured parameters."""

import json
import logging
import re
from typing import Any, Callable, Dict, Optional

# Module-level logger named after this module; PreferenceParser reports
# recoverable parse and validation problems through it.
logger = logging.getLogger(name=__name__)


class PreferenceParser:
    """Parser for extracting structured parameters from LLM responses.

    Parsing has two stages. If the response contains a JSON object, the
    first one found is used. Otherwise keyword/regex extraction (English and
    Chinese keywords) is applied to the raw text. Every value is then
    validated, and missing or invalid values fall back to the configured
    defaults.
    """

    DEFAULT_PARAMS = {
        "min_participation": 0.6,
        "max_pattern_size": 5,
        "priority": "confidence",
        "min_confidence": 0.5,  # minimum confidence for association rules
    }

    # Valid priority values mapped to the lower-case keywords that select
    # them. Order matters: the first category whose keyword occurs in a
    # value wins.
    PRIORITY_KEYWORDS = {
        "confidence": ("confidence", "置信度", "conf"),
        "participation": ("participation", "参与率", "part"),
        "size": ("size", "大小", "规模"),
    }

    # Inclusive range accepted for max_pattern_size.
    MIN_PATTERN_SIZE = 1
    MAX_PATTERN_SIZE = 20

    # Optional quote around a key or value: ASCII quotes (JSON-like or
    # Python-dict-like output) and Chinese curly quotes.
    _QUOTE = "[\"'\u201c\u201d]?"
    # Key/value separator: ASCII or full-width colon (common in Chinese
    # output), equals sign, or whitespace, with optional quotes around it.
    _SEP = _QUOTE + r"[:：=\s]*" + _QUOTE
    # Decimal number. Unlike [0-9.]+ it does not swallow a trailing
    # sentence period ("0.7." -> "0.7").
    _FLOAT = r"(\d+(?:\.\d+)?|\.\d+)"
    _INT = r"(\d+)"

    # Pattern lists are tried in order; the first one that matches and
    # converts wins.
    _PARTICIPATION_PATTERNS = (
        r"min[_\s]*participation" + _SEP + _FLOAT,
        r"参与率" + _SEP + _FLOAT,
        r"participation" + _SEP + _FLOAT,
    )
    _SIZE_PATTERNS = (
        r"max[_\s]*pattern[_\s]*size" + _SEP + _INT,
        r"模式大小" + _SEP + _INT,
        r"pattern[_\s]*size" + _SEP + _INT,
        r"(\d+)[阶个]",  # last resort: a bare count such as "3阶" / "3个"
    )
    _CONFIDENCE_PATTERNS = (
        r"min[_\s]*confidence" + _SEP + _FLOAT,
        r"置信度" + _SEP + _FLOAT,
        r"confidence" + _SEP + _FLOAT,
    )
    _PRIORITY_PATTERNS = (
        r"priority" + _SEP + r"(\w+)",
        r"优先级" + _SEP + r"(\w+)",
    )

    def __init__(self, default_params: Optional[Dict[str, Any]] = None):
        """Initialize parser with default parameters.

        Args:
            default_params: Optional overrides for DEFAULT_PARAMS. Keys may
                be schema names (e.g. ``min_participation``) or the
                config-style ``default_``-prefixed names
                (e.g. ``default_min_participation``).
        """
        self.default_params = self._normalize_default_params(default_params)

    def _normalize_default_params(
        self, default_params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Merge config defaults over DEFAULT_PARAMS using parser schema keys.

        ``default_``-prefixed config keys are renamed to schema keys. Any
        other key is copied through unchanged. Values are not validated.

        Args:
            default_params: Config-supplied defaults, or None/empty for none.

        Returns:
            A new dict; DEFAULT_PARAMS itself is never mutated.
        """
        normalized = self.DEFAULT_PARAMS.copy()
        if not default_params:
            return normalized

        key_mapping = {
            "default_min_participation": "min_participation",
            "default_max_pattern_size": "max_pattern_size",
            "default_priority": "priority",
            "default_min_confidence": "min_confidence",
        }
        for key, value in default_params.items():
            normalized[key_mapping.get(key, key)] = value
        return normalized

    def parse_preference(self, text: str) -> Dict[str, Any]:
        """Parse LLM output into structured parameters.

        Args:
            text: Raw LLM response text. None or "" yields the defaults.

        Returns:
            Dictionary with every schema parameter present and validated.
        """
        if not text:
            return self._validate_and_fill_defaults({})

        # Prefer an embedded JSON object when one exists.
        json_str = self._extract_json(text)
        if json_str is not None:
            try:
                params = json.loads(json_str)
            except json.JSONDecodeError as e:
                logger.warning("JSON decode error: %s", e)
            else:
                return self._validate_and_fill_defaults(params)

        # No usable JSON: fall back to keyword/regex extraction.
        params = self._extract_with_regex(text)
        return self._validate_and_fill_defaults(params)

    def _extract_json(self, text: str) -> Optional[str]:
        """Extract the first JSON object embedded in text.

        Each ``{`` is tried as the start of a JSON document, so objects of
        any nesting depth are returned whole (not an inner fragment).

        Args:
            text: Input text.

        Returns:
            The exact substring holding the first decodable JSON object,
            or None if there is none.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, dict):
                    return text[start:end]
            start = text.find("{", start + 1)
        return None

    @staticmethod
    def _search_first(
        patterns, text: str, cast: Callable[[str], Any]
    ) -> Optional[Any]:
        """Return cast(group 1) of the first pattern that matches and converts.

        Args:
            patterns: Regex strings, tried in order, case-insensitively.
            text: Text to search.
            cast: Converter applied to the captured group (e.g. float, int).

        Returns:
            The converted value, or None if no pattern yields one.
        """
        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                try:
                    return cast(match.group(1))
                except ValueError:
                    continue
        return None

    def _extract_with_regex(self, text: str) -> Dict[str, Any]:
        """Extract parameters using regex patterns.

        Args:
            text: Input text.

        Returns:
            Dictionary containing only the parameters that were found.
        """
        params: Dict[str, Any] = {}
        for key, patterns, cast in (
            ("min_participation", self._PARTICIPATION_PATTERNS, float),
            ("max_pattern_size", self._SIZE_PATTERNS, int),
            ("min_confidence", self._CONFIDENCE_PATTERNS, float),
        ):
            value = self._search_first(patterns, text, cast)
            if value is not None:
                params[key] = value

        # First recognised priority wins, matching the other fields.
        for pattern in self._PRIORITY_PATTERNS:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                priority = self._match_priority(match.group(1))
                if priority is not None:
                    params["priority"] = priority
                    break

        return params

    @classmethod
    def _match_priority(cls, value: str) -> Optional[str]:
        """Map a free-form priority string to a valid priority.

        An exact (case-insensitive) valid value is returned as-is. Otherwise
        the first category with a keyword contained in the value is used.

        Returns:
            One of the PRIORITY_KEYWORDS keys, or None if unrecognised.
        """
        normalized = value.strip().lower()
        if normalized in cls.PRIORITY_KEYWORDS:
            return normalized
        for priority, keywords in cls.PRIORITY_KEYWORDS.items():
            if any(keyword in normalized for keyword in keywords):
                return priority
        return None

    @staticmethod
    def _coerce(value: Any, cast: Callable[[Any], Any]) -> Optional[Any]:
        """Convert value with cast; None for booleans or unconvertible values."""
        # bool is an int subclass; JSON true/false is not a valid number here.
        if isinstance(value, bool):
            return None
        try:
            return cast(value)
        except (ValueError, TypeError, OverflowError):  # int(inf) -> Overflow
            return None

    def _validate_and_fill_defaults(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Validate parameters and fill missing ones with defaults.

        Invalid or out-of-range values are logged and replaced by defaults.

        Args:
            params: Parsed parameters (unknown keys are ignored).

        Returns:
            Validated and complete parameter dictionary.
        """
        result = self.default_params.copy()
        if not isinstance(params, dict):
            logger.warning(
                "Expected a dict of parameters, got %s", type(params).__name__
            )
            return result

        # Ratio-valued parameters must lie in [0, 1].
        for key in ("min_participation", "min_confidence"):
            if key not in params:
                continue
            val = self._coerce(params[key], float)
            if val is None:
                logger.warning("Invalid %s: %s", key, params[key])
            elif 0.0 <= val <= 1.0:
                result[key] = val
            else:
                logger.warning("%s out of range: %s", key, val)

        if "max_pattern_size" in params:
            val = self._coerce(params["max_pattern_size"], int)
            if val is None:
                logger.warning(
                    "Invalid max_pattern_size: %s", params["max_pattern_size"]
                )
            elif self.MIN_PATTERN_SIZE <= val <= self.MAX_PATTERN_SIZE:
                result["max_pattern_size"] = val
            else:
                logger.warning("max_pattern_size out of range: %s", val)

        if "priority" in params:
            priority = self._match_priority(str(params["priority"]))
            if priority is None:
                logger.warning("Invalid priority: %s", params["priority"])
            else:
                result["priority"] = priority

        return result

