from typing import List, Union
from llama_cpp import Llama
import numpy as np

class EmbeddingGenerator:
    """
    基于 llama.cpp 的本地 Embedding 生成器
    使用 ggml-org/embeddinggemma-300m-qat-Q8_0.gguf 模型
    """
    
    def __init__(
        self,
        model_path: str = "/root/.node-llama-cpp/models/hf_ggml-org_embeddinggemma-300m-qat-Q8_0.gguf",
        n_ctx: int = 512,
        n_threads: int = 4,
        embedding: bool = True
    ):
        """
        初始化 Embedding 生成器
        
        Args:
            model_path: GGUF 模型路径
            n_ctx: 上下文长度
            n_threads: 使用的线程数
            embedding: 是否启用 embedding 模式
        """
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        
        # 加载模型
        print(f"正在加载 Embedding 模型: {model_path}")
        self.llm = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            n_threads=n_threads,
            embedding=embedding,
            verbose=False  # 关闭详细日志
        )
        self.embedding_dim = 768  # embeddinggemma-300m 的维度
        print(f"✅ 模型加载完成，向量维度: {self.embedding_dim}")
    
    def encode(
        self,
        texts: Union[str, List[str]],
        normalize: bool = True
    ) -> np.ndarray:
        """
        将文本编码为向量
        
        Args:
            texts: 单个文本或文本列表
            normalize: 是否归一化向量（L2范数）
        
        Returns:
            向量数组 (N, 768)
        """
        # 统一转为列表
        if isinstance(texts, str):
            texts = [texts]
        
        embeddings = []
        
        for text in texts:
            # 截断超长文本
            if len(text) > self.n_ctx * 4:  # 粗略估计字符数
                text = text[:self.n_ctx * 4]
            
            # 生成 embedding
            output = self.llm.embed(text)
            embeddings.append(output)
        
        # 转为 numpy 数组
        embeddings = np.array(embeddings)
        
        # L2 归一化（便于计算余弦相似度）
        if normalize:
            embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        
        return embeddings
    
    def similarity(
        self,
        text1: str,
        text2: str
    ) -> float:
        """
        计算两个文本的相似度（余弦相似度）
        
        Returns:
            相似度分数 (-1 ~ 1)
        """
        emb1,emb2 = self.encode([text1, text2], normalize=True)
        # 余弦相似度 = 点积（已归一化）
        return float(np.dot(emb1, emb2))
    
    def batch_encode(
        self,
        texts: List[str],
        batch_size: int = 8
    ) -> np.ndarray:
        """
        批量编码（带进度显示）
        
        Args:
            texts: 文本列表
            batch_size: 批次大小
        
        Returns:
            向量数组 (N, 768)
        """
        all_embeddings = []
        
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.encode(batch, normalize=True)
            all_embeddings.append(batch_embeddings)
            
            if (i // batch_size + 1) % 10 == 0:
                print(f"已处理: {min(i + batch_size, len(texts))}/{len(texts)}")
        
        return np.vstack(all_embeddings)
    
    def __del__(self):
        """清理资源"""
        if hasattr(self, 'llm'):
            del self.llm

# ==================== 使用示例 ====================

if __name__ == "__main__":
    # 1.初始化生成器
    embedder = EmbeddingGenerator(
        model_path="/root/.node-llama-cpp/models/hf_ggml-org_embeddinggemma-300m-qat-Q8_0.gguf",
        n_threads=4
    )



# 2. 单文本编码
text = "OpenClaw 是一个 AI 智能体框架"
vector = embedder.encode(text)
print(f"文本: {text}")
print(f"向量维度: {vector.shape}")  # (1, 768)
print(f"向量前5个值: {vector[0, :5]}")

# 3. 批量编码
texts = [
    "OpenClaw 智能体框架",
    "QQ 机器人开发",
    "桂林电子科技大学",
    "金丝熊仓鼠喜欢吃香蕉干"
]
vectors = embedder.encode(texts)
print(f"\n批量编码: {vectors.shape}")  # (4, 768)

# 4. 计算文本相似度
sim = embedder.similarity(
    "OpenClaw 是 AI 框架",
    "OpenClaw 用于开发智能体"
)
print(f"\n相似度: {sim:.4f}")

# 5. 批量相似度搜索
query = "AI 智能体开发"
documents = [
    "OpenClaw是一个 AI 智能体框架",
        "桂林米粉是桂林的特色美食",
        "Python 是一种编程语言",
        "AI 智能体可以自动执行任务"
    ]



query_vec = embedder.encode(query)
doc_vecs = embedder.encode(documents)

# 计算余弦相似度
similarities = np.dot(doc_vecs, query_vec.T).flatten()

print(f"\n查询: {query}")
print("最相关文档:")
for idx in np.argsort(similarities)[::-1][:3]:
    print(f"  [{similarities[idx]:.4f}] {documents[idx]}")