"""LLM client for loading and running local fine-tuned models."""

import json
import logging
import os
import sys
from threading import Thread
from typing import List, Optional, Tuple

import torch
import yaml
from transformers import AutoTokenizer, TextIteratorStreamer

logger = logging.getLogger(__name__)


class LLMClient:
    """Client for loading and running local LLM models.

    Two inference backends are supported:

    * LlamaFactory's ``ChatModel`` (preferred). It applies the configured chat
      template and, when ``model.adapter_name_or_path`` is set, the adapter.
      Stored in ``self.chat_model``.
    * A plain ``transformers`` fallback, used when LlamaFactory cannot be
      imported or fails to load. Stored in ``self.model`` / ``self.tokenizer``.
    """

    # Location of the LlamaFactory sources used when neither the config key
    # ``model.llamafactory_path`` nor the ``LLAMAFACTORY_PATH`` environment
    # variable is set. Kept as the historical default for backward compatibility.
    DEFAULT_LLAMAFACTORY_PATH = "/home/ubuntu/codebase/yexijia/保研/LlamaFactory/src"

    def __init__(self, config_path: str = "config/config.yaml"):
        """Read the YAML configuration and load the model.

        Args:
            config_path: Path to the YAML configuration file. It must provide
                ``inference.device`` and ``model.template``, plus the keys
                read by the model loaders.

        Raises:
            OSError: If the configuration file cannot be opened.
            KeyError: If a required configuration key is missing.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        self.model = None
        self.tokenizer = None
        self.device = self.config['inference']['device']
        self.template = self.config['model']['template']
        self._load_model()

    def _resolve_llamafactory_path(self) -> str:
        """Return the directory to prepend to ``sys.path`` before importing LlamaFactory.

        Resolution order: ``model.llamafactory_path`` from the config, then the
        ``LLAMAFACTORY_PATH`` environment variable, then
        ``DEFAULT_LLAMAFACTORY_PATH``.
        """
        configured = (self.config.get('model') or {}).get('llamafactory_path')
        return (
            configured
            or os.environ.get('LLAMAFACTORY_PATH')
            or self.DEFAULT_LLAMAFACTORY_PATH
        )

    def _load_model(self):
        """Load the model, preferring LlamaFactory and falling back to transformers.

        On success ``self.chat_model`` holds a LlamaFactory ``ChatModel``. If
        LlamaFactory cannot be imported, or raises while loading, the base
        model is loaded by ``_load_model_fallback`` instead and
        ``self.chat_model`` is left unset.
        """
        # ``generate`` routes to LlamaFactory whenever ``chat_model`` exists, so
        # drop one left over from a previous load; otherwise a reload that ends
        # on the fallback path would keep using the stale backend.
        self.__dict__.pop('chat_model', None)
        try:
            llamafactory_path = self._resolve_llamafactory_path()
            if llamafactory_path not in sys.path:
                sys.path.insert(0, llamafactory_path)

            from llamafactory.chat import ChatModel

            model_config = self.config['model']
            args = {
                'model_name_or_path': model_config['model_name_or_path'],
                'template': model_config['template'],
                # NOTE(review): confirm 'none' is an accepted finetuning_type
                # for the installed LlamaFactory version.
                'finetuning_type': model_config.get('finetuning_type', 'none'),
                'trust_remote_code': model_config.get('trust_remote_code', True),
            }
            # Only pass an adapter when one is configured.
            if model_config.get('adapter_name_or_path'):
                args['adapter_name_or_path'] = model_config['adapter_name_or_path']

            self.chat_model = ChatModel(args)
            logger.info("Model loaded successfully with LlamaFactory")

        except ImportError as e:  # also covers ModuleNotFoundError (a subclass)
            logger.warning("LlamaFactory not available: %s, using fallback method", e)
            self._load_model_fallback()
        except Exception as e:
            logger.exception("Failed to load model with LlamaFactory: %s", e)
            logger.info("Falling back to direct transformers loading")
            self._load_model_fallback()

    def _load_model_fallback(self):
        """Load the base model and tokenizer directly with ``transformers``.

        This path loads only ``model.model_name_or_path``; a configured
        ``adapter_name_or_path`` is not applied (a warning is logged).
        """
        from transformers import AutoModelForCausalLM

        model_config = self.config['model']
        model_path = model_config['model_name_or_path']
        # Honour the same config switch the LlamaFactory path uses.
        trust_remote_code = model_config.get('trust_remote_code', True)

        if model_config.get('adapter_name_or_path'):
            logger.warning(
                "adapter_name_or_path is set, but the transformers fallback "
                "loads only the base model; the adapter is ignored"
            )

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=trust_remote_code,
        )

        torch_dtype = torch.float16 if self.config['inference']['use_fp16'] else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            # Let transformers/accelerate place the weights on GPU.
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=trust_remote_code,
        )

        if self.device == "cpu":
            self.model = self.model.to(self.device)

        # Some tokenizers ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Model loaded using fallback method")

    def generate(self, prompt: str, system: Optional[str] = None) -> str:
        """Generate a response for a single-turn prompt.

        Each call is independent: no conversation history is kept or sent.
        With the LlamaFactory backend, a response containing a ``{...}``
        object is trimmed down to that object (see ``_extract_json``).

        Args:
            prompt: User prompt text.
            system: Optional system message.

        Returns:
            The generated text, ``""`` if the backend returned no response,
            or ``"Error: <message>"`` if generation raised (errors are logged,
            not re-raised).
        """
        try:
            chat_model = getattr(self, 'chat_model', None)
            if chat_model is None:
                return self._generate_fallback(prompt, system)

            # A fresh single-turn message list per call. LlamaFactory takes
            # the system prompt through the dedicated ``system`` argument;
            # its templates encode ``messages`` as user/assistant turns only.
            messages = [{"role": "user", "content": prompt}]
            responses = chat_model.chat(messages, system=system or None)
            if not responses:
                return ""
            return self._extract_json(responses[0].response_text)

        except Exception as e:
            logger.exception("Generation error: %s", e)
            return f"Error: {str(e)}"

    @staticmethod
    def _balanced_brace_spans(text: str) -> List[Tuple[int, int]]:
        """Return ``(start, end)`` index pairs of consecutive balanced ``{...}`` spans.

        ``end`` is the index of the closing brace (inclusive). The scan is
        string-aware: braces inside double-quoted strings (with backslash
        escapes) are ignored, so a value such as ``"}"`` does not end a span
        early. Scanning stops at the first ``{`` that is never closed.
        """
        spans = []
        pos = 0
        while True:
            start = text.find("{", pos)
            if start == -1:
                return spans

            depth = 0
            in_string = False
            escaped = False
            end = None
            for i in range(start, len(text)):
                ch = text[i]
                if in_string:
                    if escaped:
                        escaped = False
                    elif ch == "\\":
                        escaped = True
                    elif ch == '"':
                        in_string = False
                elif ch == '"':
                    in_string = True
                elif ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        end = i
                        break

            if end is None:
                return spans
            spans.append((start, end))
            pos = end + 1

    @classmethod
    def _extract_json(cls, text: str) -> str:
        """Trim a model response down to its JSON object, if it contains one.

        The output sometimes holds several ``{...}`` objects (for example,
        echoed earlier content). Returns the first balanced span that parses
        as JSON. If none parse but several spans exist, returns the last span.
        Otherwise returns ``text`` unchanged.
        """
        spans = cls._balanced_brace_spans(text)
        for start, end in spans:
            candidate = text[start:end + 1]
            try:
                json.loads(candidate)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            return candidate

        if len(spans) > 1:
            start, end = spans[-1]
            return text[start:end + 1]
        return text

    def _generate_fallback(self, prompt: str, system: Optional[str] = None) -> str:
        """Generate with the plain ``transformers`` model.

        No chat template is applied on this path; the system message, if any,
        is prepended to the prompt separated by a blank line.

        Returns:
            Only the newly generated text, decoded without special tokens and
            stripped of surrounding whitespace.
        """
        full_prompt = f"{system}\n\n{prompt}" if system else prompt

        # Calling the tokenizer (rather than ``encode``) also yields the
        # attention mask; with pad_token possibly equal to eos_token,
        # generate() cannot reliably infer the mask on its own.
        encoded = self.tokenizer(full_prompt, return_tensors="pt")
        if self.device == "cuda":
            encoded = encoded.to(self.device)
        input_ids = encoded["input_ids"]

        infer_config = self.config['inference']
        generation_config = {
            'max_new_tokens': infer_config['max_new_tokens'],
            'do_sample': infer_config['do_sample'],
            'pad_token_id': self.tokenizer.pad_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
        }
        # temperature/top_p only affect sampling; for greedy decoding they
        # change nothing but trigger transformers warnings.
        if infer_config['do_sample']:
            generation_config['temperature'] = infer_config['temperature']
            generation_config['top_p'] = infer_config['top_p']

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=encoded.get("attention_mask"),
                **generation_config,
            )

        # Decode only the tokens generated after the prompt.
        generated_text = self.tokenizer.decode(
            outputs[0][input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return generated_text.strip()

    def load_model(self):
        """Alias for ``_load_model`` kept for backward compatibility."""
        self._load_model()

