# -*- coding: utf-8 -*-
"""
Vector store — embeddings generated via Ollama, indexing and retrieval via FAISS.

Cosine similarity is implemented as inner product (IndexFlatIP) over
L2-normalized vectors.
"""
import os
import json
import numpy as np
import requests  # NOTE(review): unused in this module (urllib is used for HTTP); kept deliberately
from config import INDEX_DIR, OLLAMA_URL, OLLAMA_EMBED_MODEL, TOP_K


class VectorStore:
    """FAISS-backed document store with Ollama embeddings.

    `self.documents` is kept parallel to the FAISS index rows: row i of the
    index corresponds to `self.documents[i]`.
    """

    def __init__(self):
        self.index = None        # FAISS IndexFlatIP, created lazily
        self.documents = []      # list of {'id', 'content', 'metadata'}, parallel to index rows
        self.dimension = 768     # nomic-embed-text embedding dimension (adjusted at runtime if needed)
        self.index_file = os.path.join(INDEX_DIR, "faiss.index")
        self.docs_file = os.path.join(INDEX_DIR, "documents.json")
        self.faiss = None        # lazily imported faiss module

    def _load_faiss(self):
        """Lazily import FAISS (the native import is slow, defer until needed)."""
        if self.faiss is None:
            import faiss
            self.faiss = faiss

    def _embed_with_ollama(self, text, retry_count=3):
        """Generate one embedding via the Ollama /api/embeddings endpoint.

        Retries up to *retry_count* times with linear backoff; raises the last
        error if every attempt fails.  Returns a list of floats (possibly empty
        if the server response lacks an "embedding" key).
        """
        import time
        import urllib.request
        import urllib.error

        url = f"{OLLAMA_URL}/api/embeddings"

        # Guarantee a non-empty string payload.
        if not text or not isinstance(text, str):
            text = "empty"
        # Strip null characters, which break JSON transport.
        text = text.replace('\x00', '')
        # Truncate long input (nomic-embed-text context is ~2048 tokens;
        # Chinese is roughly 1.5 chars/token, so 1000 chars is conservative).
        max_length = 1000
        if len(text) > max_length:
            text = text[:max_length]

        payload = {
            "model": OLLAMA_EMBED_MODEL,
            "prompt": text
        }

        last_error = None
        for attempt in range(retry_count):
            try:
                # urllib instead of requests to avoid potential encoding issues.
                data = json.dumps(payload, ensure_ascii=False).encode('utf-8')
                req = urllib.request.Request(
                    url,
                    data=data,
                    headers={'Content-Type': 'application/json; charset=utf-8'},
                    method='POST'
                )
                with urllib.request.urlopen(req, timeout=120) as response:
                    result = json.loads(response.read().decode('utf-8'))
                    return result.get("embedding", [])
            except urllib.error.HTTPError as e:
                last_error = e
                error_body = e.read().decode('utf-8') if e.fp else 'N/A'
                print(f"Ollama HTTP 错误 (尝试 {attempt+1}/{retry_count}): {e.code} {e.reason}")
                print(f"响应内容: {error_body[:500]}")
                print(f"请求文本长度: {len(text)}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
            except Exception as e:
                last_error = e
                print(f"Ollama 嵌入失败 (尝试 {attempt+1}/{retry_count}): {e}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
        raise last_error

    def _embed_batch(self, texts):
        """Embed every text in *texts*, pacing requests so Ollama is not flooded.

        BUG-FIX: the original contained a second, unreachable copy of this loop
        after the `return` statement (dead code); it has been removed.
        """
        import time
        embeddings = []
        for i, text in enumerate(texts):
            # Debug output so a failing chunk can be identified afterwards.
            print(f"  生成向量 {i+1}/{len(texts)}...")
            print(f"  文本长度: {len(text)}, 前50字符: {repr(text[:50])}")
            embeddings.append(self._embed_with_ollama(text))
            # Small delay between requests to avoid hammering the server.
            if i < len(texts) - 1:
                time.sleep(1.0)
        return embeddings

    def _init_index(self):
        """Create an empty inner-product index if none exists yet."""
        self._load_faiss()
        if self.index is None:
            self.index = self.faiss.IndexFlatIP(self.dimension)

    def load_index(self):
        """Load index and documents from disk.

        Returns True on success; on any failure falls back to a fresh empty
        index and returns False.
        """
        self._load_faiss()
        if os.path.exists(self.index_file) and os.path.exists(self.docs_file):
            try:
                print("正在加载已有索引...")
                # FAISS on Windows cannot open non-ASCII paths; copy to an
                # ASCII-named temp file first and read from there.
                import tempfile
                import shutil
                try:
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                        tmp_path = tmp.name
                    try:
                        shutil.copy2(self.index_file, tmp_path)
                        self.index = self.faiss.read_index(tmp_path)
                    finally:
                        # BUG-FIX: always remove the temp copy, even when
                        # read_index raises (original leaked it on failure).
                        if os.path.exists(tmp_path):
                            os.unlink(tmp_path)
                except Exception as e:
                    print(f"临时文件方式失败,尝试直接读取: {e}")
                    self.index = self.faiss.read_index(self.index_file)
                with open(self.docs_file, 'r', encoding='utf-8') as f:
                    self.documents = json.load(f)
                print(f"索引加载完成,共 {len(self.documents)} 个文档块")
                return True
            except Exception as e:
                print(f"加载索引失败: {e}")
                self._init_index()
                self.documents = []
                return False
        else:
            print("未找到已有索引,创建新索引")
            self._init_index()
            self.documents = []
            return False

    def save_index(self):
        """Persist the index and the document list to disk (no-op if no index)."""
        self._load_faiss()
        if self.index is None:
            return
        # Make sure the target directory exists.
        os.makedirs(os.path.dirname(self.index_file), exist_ok=True)
        # FAISS on Windows cannot write to non-ASCII paths; write to a temp
        # file and move it into place.
        import tempfile
        import shutil
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                tmp_path = tmp.name
            try:
                self.faiss.write_index(self.index, tmp_path)
                shutil.move(tmp_path, self.index_file)
            except Exception:
                # BUG-FIX: clean up the temp file before falling back
                # (original leaked it when write_index/move failed).
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
                raise
        except Exception as e:
            print(f"临时文件方式失败,尝试直接写入: {e}")
            self.faiss.write_index(self.index, self.index_file)
        with open(self.docs_file, 'w', encoding='utf-8') as f:
            json.dump(self.documents, f, ensure_ascii=False, indent=2)
        print(f"索引已保存,共 {len(self.documents)} 个文档块")

    def add_documents(self, chunks, metadata=None):
        """Embed *chunks*, add them to the index, and persist.

        Returns the number of chunks added.  *metadata* (a dict or None) is
        attached to every chunk.
        """
        if not chunks:
            return 0
        self._load_faiss()
        self._init_index()

        print(f"正在为 {len(chunks)} 个文本块生成向量...")
        embeddings = self._embed_batch(chunks)

        # Adjust the index dimension to the actual embedding size.
        if embeddings and len(embeddings[0]) != self.dimension:
            if self.index.ntotal > 0:
                # BUG-FIX: the original silently replaced a non-empty index
                # here, orphaning every previously indexed document while
                # self.documents still listed them.  Fail loudly instead.
                raise ValueError(
                    f"嵌入维度 {len(embeddings[0])} 与现有索引维度 {self.dimension} 不一致"
                )
            self.dimension = len(embeddings[0])
            self.index = self.faiss.IndexFlatIP(self.dimension)
            print(f"更新向量维度为: {self.dimension}")

        # L2-normalize so that inner product equals cosine similarity.
        embeddings_np = np.array(embeddings).astype('float32')
        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
        embeddings_np = embeddings_np / (norms + 1e-10)

        start_idx = len(self.documents)
        self.index.add(embeddings_np)

        # Keep self.documents parallel to the index rows.
        for i, chunk in enumerate(chunks):
            self.documents.append({
                'id': start_idx + i,
                'content': chunk,
                'metadata': metadata or {}
            })

        # Auto-save after every successful addition.
        self.save_index()
        return len(chunks)

    def search(self, query, top_k=None):
        """Return up to *top_k* chunks most similar to *query*.

        Each result is {'content', 'score', 'metadata'}; an empty list is
        returned when the index is empty.
        """
        if top_k is None:
            top_k = TOP_K
        if self.index is None or self.index.ntotal == 0:
            return []
        self._load_faiss()

        # Embed and L2-normalize the query.
        query_embedding = self._embed_with_ollama(query)
        query_np = np.array([query_embedding]).astype('float32')
        norm = np.linalg.norm(query_np)
        query_np = query_np / (norm + 1e-10)

        k = min(top_k, self.index.ntotal)
        scores, indices = self.index.search(query_np, k)

        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS reports -1 for missing results; also guard against any
            # index/documents desync.
            if 0 <= idx < len(self.documents):
                doc = self.documents[idx]
                results.append({
                    'content': doc['content'],
                    'score': float(scores[0][i]),
                    'metadata': doc.get('metadata', {})
                })
        return results

    def delete_by_filename(self, filename):
        """Remove every chunk whose metadata 'filename' equals *filename*.

        Flat FAISS indexes do not support deletion, so the index is rebuilt by
        re-embedding the remaining chunks.  Returns the number of deleted
        chunks.
        """
        if not self.documents:
            return 0
        self._load_faiss()

        remaining_docs = []
        deleted_count = 0
        for doc in self.documents:
            if doc.get('metadata', {}).get('filename') != filename:
                remaining_docs.append(doc)
            else:
                deleted_count += 1

        if deleted_count > 0:
            # Rebuild the index from scratch with the surviving documents.
            self.documents = []
            self.index = self.faiss.IndexFlatIP(self.dimension)
            if remaining_docs:
                chunks = [doc['content'] for doc in remaining_docs]
                metadatas = [doc.get('metadata', {}) for doc in remaining_docs]
                embeddings = self._embed_batch(chunks)
                embeddings_np = np.array(embeddings).astype('float32')
                norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
                embeddings_np = embeddings_np / (norms + 1e-10)
                self.index.add(embeddings_np)
                for i, (chunk, meta) in enumerate(zip(chunks, metadatas)):
                    self.documents.append({
                        'id': i,
                        'content': chunk,
                        'metadata': meta
                    })
            self.save_index()
        return deleted_count

    def clear(self):
        """Drop all documents, reset the index, and persist the empty state."""
        self._load_faiss()
        self.index = self.faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.save_index()
        print("索引已清空")

    def get_stats(self):
        """Return {'total_chunks', 'total_files', 'files'} summary statistics."""
        files = set()
        for doc in self.documents:
            filename = doc.get('metadata', {}).get('filename')
            if filename:
                files.add(filename)
        return {
            'total_chunks': len(self.documents),
            'total_files': len(files),
            'files': list(files)
        }


# Module-level singleton used by the rest of the application.
vector_store = VectorStore()