"""
RAG (Retrieval-Augmented Generation) 核心实现
包含：文档分块、Embedding生成、向量相似度匹配、sqlite-vec查询
"""

import sqlite3
import numpy as np
from typing import List, Tuple, Dict
import json
import hashlib


class DocumentChunk:
    """文档片段"""
    def __init__(self, text: str, source: str, chunk_id: str, embedding: np.ndarray = None):
        self.text = text              # 文本内容
        self.source = source          # 来源文件
        self.chunk_id = chunk_id      # 唯一ID
        self.embedding = embedding    # 向量表示
    
    def __repr__(self):
        return f"Chunk(id={self.chunk_id}, text_len={len(self.text)}, source={self.source})"


class TextChunker:
    """
    文档分块器
    将长文本切分成适合检索的小片段
    """
    
    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """
        Args:
            chunk_size: 每个块的最大字符数
            overlap: 相邻块之间的重叠字符数（保持上下文连贯）
        """
        self.chunk_size = chunk_size
        self.overlap = overlap
    
    def chunk_by_paragraph(self, text: str, source: str) -> List[DocumentChunk]:
        """
        按段落分块
        优先保证段落完整性，如果段落太长再切分
        """
        chunks = []
        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
        
        for i, para in enumerate(paragraphs):
            if len(para) <= self.chunk_size:
                # 段落较短，直接作为一个块
                chunk_id = hashlib.md5(f"{source}:{i}:{para[:50]}".encode()).hexdigest()[:12]
                chunks.append(DocumentChunk(para, source, chunk_id))
            else:
                # 段落太长，需要进一步切分
                sub_chunks = self._split_long_text(para, source, i)
                chunks.extend(sub_chunks)
        
        return chunks
    
    def _split_long_text(self, text: str, source: str, para_idx: int) -> List[DocumentChunk]:
        """将长文本按固定大小切分，带重叠"""
        chunks = []
        start = 0
        chunk_idx = 0
        
        while start < len(text):
            end = min(start + self.chunk_size, len(text))
            
            # 尽量在句子边界切分
            if end < len(text):
                # 找最近的句号、问号、换行
                for sep in ['。', '？', '!', '\n', ' ']:
                    pos = text.rfind(sep, start, end)
                    if pos > start + self.chunk_size // 2:  # 至少保留一半内容
                        end = pos + 1
                        break
            
            chunk_text = text[start:end].strip()
            if chunk_text:
                chunk_id = hashlib.md5(f"{source}:{para_idx}:{chunk_idx}:{chunk_text[:30]}".encode()).hexdigest()[:12]
                chunks.append(DocumentChunk(chunk_text, source, chunk_id))
            
            # 下一个块的起始位置（带重叠）
            start = end - self.overlap if end < len(text) else end
            chunk_idx += 1
        
        return chunks


class SimpleEmbedding:
    """
    简单的 Embedding 生成器
    实际项目中应该使用专业的模型（如 embeddinggemma、OpenAI 等）
    这里用简化版演示原理
    """
    
    def __init__(self, dim: int = 768):
        self.dim = dim
        # 简单的词表（实际应该用 tokenizer）
        self.vocab = {}
    
    def _tokenize(self, text: str) -> List[str]:
        """简单分词（按字符）"""
        return list(text.lower())
    
    def embed(self, text: str) -> np.ndarray:
        """
        生成文本的向量表示
        简化版：基于字符频率的哈希向量
        实际应该用预训练模型
        """
        tokens = self._tokenize(text)
        vector = np.zeros(self.dim)
        
        for token in tokens:
            # 用哈希确定位置
            idx = hash(token) % self.dim
            vector[idx] += 1
        
        # L2 归一化
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm
        
        return vector
    
    def embed_batch(self, texts: List[str]) -> np.ndarray:
        """批量生成向量"""
        return np.array([self.embed(t) for t in texts])


class VectorStore:
    """
    向量存储与检索
    使用 sqlite-vec 或简单的 numpy 实现
    """
    
    def __init__(self, db_path: str = ":memory:", dim: int = 768):
        self.db_path = db_path
        self.dim = dim
        self.conn = None
        self._init_db()
    
    def _init_db(self):
        """初始化数据库"""
        self.conn = sqlite3.connect(self.db_path)
        
        # 创建表
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                chunk_id TEXT PRIMARY KEY,
                text TEXT NOT NULL,
                source TEXT NOT NULL,
                embedding BLOB NOT NULL
            )
        """)
        
        # 创建虚拟表用于向量搜索（需要 sqlite-vec 扩展）
        # 如果没有扩展，我们用纯 Python 实现
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS vectors (
                chunk_id TEXT PRIMARY KEY,
                vector_json TEXT NOT NULL,
                FOREIGN KEY (chunk_id) REFERENCES chunks(chunk_id)
            )
        """)
        
        self.conn.commit()
    
    def add_chunk(self, chunk: DocumentChunk):
        """添加文档片段"""
        # 序列化向量
        embedding_bytes = chunk.embedding.tobytes()
        vector_json = json.dumps(chunk.embedding.tolist())
        
        self.conn.execute(
            "INSERT OR REPLACE INTO chunks (chunk_id, text, source, embedding) VALUES (?, ?, ?, ?)",
            (chunk.chunk_id, chunk.text, chunk.source, embedding_bytes)
        )
        
        self.conn.execute(
            "INSERT OR REPLACE INTO vectors (chunk_id, vector_json) VALUES (?, ?)",
            (chunk.chunk_id, vector_json)
        )
        
        self.conn.commit()
    
    def add_chunks(self, chunks: List[DocumentChunk]):
        """批量添加"""
        for chunk in chunks:
            self.add_chunk(chunk)
    
    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """计算余弦相似度"""
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    
    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """
        向量相似度搜索
        返回: [(chunk_id, text, similarity), ...]
        """
        cursor = self.conn.execute("SELECT chunk_id, text, vector_json FROM chunks JOIN vectors USING (chunk_id)")
        
        results = []
        for row in cursor:
            chunk_id, text, vector_json = row
            stored_vector = np.array(json.loads(vector_json))
            
            # 计算相似度
            similarity = self._cosine_similarity(query_embedding, stored_vector)
            results.append((chunk_id, text, similarity))
        
        # 按相似度排序，返回 top_k
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]


class RAGSystem:
    """
    完整的 RAG 系统
    """
    
    def __init__(self, db_path: str = ":memory:"):
        self.chunker = TextChunker(chunk_size=300, overlap=30)
        self.embedder = SimpleEmbedding(dim=768)
        self.store = VectorStore(db_path=db_path, dim=768)
    
    def add_document(self, text: str, source: str):
        """添加文档到知识库"""
        # 1. 分块
        chunks = self.chunker.chunk_by_paragraph(text, source)
        print(f"📄 {source}: 分成 {len(chunks)} 个片段")
        
        # 2. 生成向量
        for chunk in chunks:
            chunk.embedding = self.embedder.embed(chunk.text)
        
        # 3. 存储
        self.store.add_chunks(chunks)
        print(f"✅ 已存储 {len(chunks)} 个向量")
    
    def query(self, question: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """
        检索相关文档
        返回: [(text, similarity), ...]
        """
        # 1. 问题向量化
        query_embedding = self.embedder.embed(question)
        
        # 2. 向量检索
        results = self.store.search(query_embedding, top_k=top_k)
        
        # 3. 格式化返回
        return [(text, score) for _, text, score in results]
    
    def answer(self, question: str, context_template: str = None) -> str:
        """
        生成回答（简化版，实际应该调用 LLM）
        """
        # 检索相关文档
        contexts = self.query(question, top_k=3)
        
        if not contexts:
            return "❌ 未找到相关信息"
        
        # 组装上下文
        context_text = "\n\n".join([f"[相关度: {score:.3f}] {text}" for text, score in contexts])
        
        # 简化版回答（实际应该调用 LLM）
        answer = f"""🤖 基于检索到的信息：

{context_text}

---
💡 说明：实际项目中，这里会把上述上下文传给 LLM，让 LLM 生成最终回答。
"""
        return answer


# ==================== 使用示例 ====================

if __name__ == "__main__":
    # 初始化 RAG 系统
    rag = RAGSystem(db_path="rag_demo.db")
    
    # 添加示例文档
    doc1 = """
    第二课堂成绩分为六个类别：思想政治与道德修养、科技学术与创新创业、
    文化艺术与体育、实践实习与社会工作、技能特长、其他课程。
    
    学生每学年需要获得 "思想政治与道德修养" 类学分不少于 15 分，
    "科技学术与创新创业""文化艺术与体育""实践实习与社会工作" 类学分各不少于 5 分。
    """
    
    doc2 = """
    志愿服务时长与第二课堂学分挂钩：学生在 "桂志愿" 平台每学年累计志愿服务时长 
    10 小时可获得 1 分第二课堂学分，由学生自主申请，学院团委统一认定。
    
    志愿者星级评定：20 小时为一星、50 小时为二星、80 小时为三星、
    110 小时为四星、140 小时为五星志愿者。
    """
    
    rag.add_document(doc1, "第二课堂制度")
    rag.add_document(doc2, "志愿服务管理办法")
    
    # 测试查询
    print("\n" + "="*50)
    print("🔍 测试查询：第二课堂有哪些类别？")
    print("="*50)
    answer = rag.answer("第二课堂有哪些类别？")
    print(answer)
    
    print("\n" + "="*50)
    print("🔍 测试查询：志愿服务时长怎么算学分？")
    print("="*50)
    answer = rag.answer("志愿服务时长怎么算学分？")
    print(answer)
    
    print("\n" + "="*50)
    print("🔍 测试查询：五星志愿者需要多少小时？")
    print("="*50)
    answer = rag.answer("五星志愿者需要多少小时？")
    print(answer)
