xinli/rag-python/vector_store.py
xiao12feng@outlook.com 0f490298f3 加入AI分析知识库
2025-12-20 12:08:33 +08:00

327 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
向量存储 - 使用 Ollama 生成向量,FAISS 进行索引和检索
"""
import os
import json
import numpy as np
import requests
from config import INDEX_DIR, OLLAMA_URL, OLLAMA_EMBED_MODEL, TOP_K
class VectorStore:
    """FAISS-backed vector store with Ollama-generated embeddings.

    Embeddings come from the Ollama ``/api/embeddings`` endpoint using
    ``OLLAMA_EMBED_MODEL``.  They are L2-normalized and stored in a flat
    inner-product FAISS index, so inner-product scores behave as cosine
    similarities.  Document texts and metadata are persisted in a JSON
    side-file next to the binary index so both can be reloaded together.
    """

    def __init__(self):
        self.index = None        # FAISS index; created lazily by _init_index()
        self.documents = []      # list of {'id', 'content', 'metadata'} dicts, parallel to index rows
        self.dimension = 768     # nomic-embed-text output dimension; auto-corrected in add_documents()
        self.index_file = os.path.join(INDEX_DIR, "faiss.index")
        self.docs_file = os.path.join(INDEX_DIR, "documents.json")
        self.faiss = None        # lazily imported faiss module

    def _load_faiss(self):
        """Import FAISS on first use so module import stays cheap."""
        if self.faiss is None:
            import faiss
            self.faiss = faiss

    def _embed_with_ollama(self, text, retry_count=3):
        """Return the embedding vector for *text* via Ollama.

        Retries up to *retry_count* times with linear backoff
        (2s, 4s, ...).  Re-raises the last error when every attempt fails.

        :param text: text to embed; coerced to "empty" when falsy/non-str
        :param retry_count: number of attempts before giving up
        :return: list of floats (the embedding), possibly [] if the
                 response carried no "embedding" key
        """
        import time
        import urllib.request
        import urllib.error
        url = f"{OLLAMA_URL}/api/embeddings"
        # Guard against empty or non-string input.
        if not text or not isinstance(text, str):
            text = "empty"
        # NUL bytes break JSON transport; strip them.
        text = text.replace('\x00', '')
        # nomic-embed-text's context window is roughly 2048 tokens;
        # Chinese averages ~1.5 chars/token, so 1000 chars is a safe cap.
        max_length = 1000
        if len(text) > max_length:
            text = text[:max_length]
        payload = {
            "model": OLLAMA_EMBED_MODEL,
            "prompt": text
        }
        last_error = None
        for attempt in range(retry_count):
            try:
                # urllib instead of requests to sidestep potential encoding issues.
                data = json.dumps(payload, ensure_ascii=False).encode('utf-8')
                req = urllib.request.Request(
                    url,
                    data=data,
                    headers={'Content-Type': 'application/json; charset=utf-8'},
                    method='POST'
                )
                with urllib.request.urlopen(req, timeout=120) as response:
                    result = json.loads(response.read().decode('utf-8'))
                    return result.get("embedding", [])
            except urllib.error.HTTPError as e:
                last_error = e
                error_body = e.read().decode('utf-8') if e.fp else 'N/A'
                print(f"Ollama HTTP 错误 (尝试 {attempt+1}/{retry_count}): {e.code} {e.reason}")
                print(f"响应内容: {error_body[:500]}")
                print(f"请求文本长度: {len(text)}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
            except Exception as e:
                last_error = e
                print(f"Ollama 嵌入失败 (尝试 {attempt+1}/{retry_count}): {e}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
        # All attempts exhausted.  BUGFIX: guard against `raise None`
        # when retry_count <= 0 left last_error unset.
        if last_error is not None:
            raise last_error
        raise RuntimeError("embedding failed: no attempts were made")

    def _embed_batch(self, texts):
        """Embed *texts* sequentially, pacing requests to avoid overloading Ollama.

        BUGFIX: the original contained an unreachable duplicate loop after
        the first ``return``; it has been removed.

        :param texts: iterable of strings
        :return: list of embedding vectors, one per input text
        """
        import time
        embeddings = []
        for i, text in enumerate(texts):
            # Progress/debug output for long ingests.
            print(f" 生成向量 {i+1}/{len(texts)}...")
            print(f" 文本长度: {len(text)}, 前50字符: {repr(text[:50])}")
            embeddings.append(self._embed_with_ollama(text))
            # Small delay between requests; skipped after the last one.
            if i < len(texts) - 1:
                time.sleep(1.0)
        return embeddings

    def _init_index(self):
        """Create an empty inner-product index if none exists yet."""
        self._load_faiss()
        if self.index is None:
            self.index = self.faiss.IndexFlatIP(self.dimension)

    def load_index(self):
        """Load the index and document list from disk.

        Falls back to a fresh empty index when files are missing or
        unreadable.

        :return: True when an existing index was loaded, False otherwise
        """
        self._load_faiss()
        if os.path.exists(self.index_file) and os.path.exists(self.docs_file):
            try:
                print("正在加载已有索引...")
                # FAISS cannot open non-ASCII (e.g. Chinese) paths on Windows,
                # so copy the index to an ASCII temp path before reading.
                import tempfile
                import shutil
                try:
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                        tmp_path = tmp.name
                    try:
                        shutil.copy2(self.index_file, tmp_path)
                        self.index = self.faiss.read_index(tmp_path)
                    finally:
                        # BUGFIX: remove the temp copy even when reading fails.
                        if os.path.exists(tmp_path):
                            os.unlink(tmp_path)
                except Exception as e:
                    print(f"临时文件方式失败,尝试直接读取: {e}")
                    self.index = self.faiss.read_index(self.index_file)
                with open(self.docs_file, 'r', encoding='utf-8') as f:
                    self.documents = json.load(f)
                print(f"索引加载完成,共 {len(self.documents)} 个文档块")
                return True
            except Exception as e:
                print(f"加载索引失败: {e}")
                self._init_index()
                self.documents = []
                return False
        else:
            print("未找到已有索引,创建新索引")
            self._init_index()
            self.documents = []
            return False

    def save_index(self):
        """Persist the FAISS index and the document list to disk.

        No-op when no index exists yet.
        """
        self._load_faiss()
        if self.index is not None:
            # Make sure the target directory exists.
            os.makedirs(os.path.dirname(self.index_file), exist_ok=True)
            # Write via an ASCII temp path (FAISS + non-ASCII paths on Windows),
            # then move into place.
            import tempfile
            import shutil
            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                    tmp_path = tmp.name
                self.faiss.write_index(self.index, tmp_path)
                shutil.move(tmp_path, self.index_file)
            except Exception as e:
                print(f"临时文件方式失败,尝试直接写入: {e}")
                # BUGFIX: don't leak the temp file when the first path fails.
                if tmp_path and os.path.exists(tmp_path):
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
                self.faiss.write_index(self.index, self.index_file)
            with open(self.docs_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)
            print(f"索引已保存,共 {len(self.documents)} 个文档块")

    def add_documents(self, chunks, metadata=None):
        """Embed *chunks* and add them to the index; auto-saves afterwards.

        :param chunks: list of text chunks to index
        :param metadata: single metadata dict applied to every chunk
        :return: number of chunks added (0 for empty input)
        """
        if not chunks:
            return 0
        self._load_faiss()
        self._init_index()
        print(f"正在为 {len(chunks)} 个文本块生成向量...")
        embeddings = self._embed_batch(chunks)
        # Reconcile the configured dimension with what the model returned.
        # NOTE(review): replacing the index here drops any vectors already in
        # it while self.documents is kept — preserved from the original; only
        # triggers when the default 768 mismatches the model.
        if embeddings and len(embeddings[0]) != self.dimension:
            self.dimension = len(embeddings[0])
            self.index = self.faiss.IndexFlatIP(self.dimension)
            print(f"更新向量维度为: {self.dimension}")
        # L2-normalize so inner product == cosine similarity.
        embeddings_np = np.array(embeddings).astype('float32')
        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
        embeddings_np = embeddings_np / (norms + 1e-10)
        start_idx = len(self.documents)
        self.index.add(embeddings_np)
        # Record content + metadata alongside the vectors.
        for i, chunk in enumerate(chunks):
            self.documents.append({
                'id': start_idx + i,
                'content': chunk,
                'metadata': metadata or {}
            })
        self.save_index()
        return len(chunks)

    def search(self, query, top_k=None):
        """Return the *top_k* most similar documents to *query*.

        :param query: query text
        :param top_k: result count; defaults to config TOP_K
        :return: list of {'content', 'score', 'metadata'} dicts,
                 empty when the index is missing or empty
        """
        if top_k is None:
            top_k = TOP_K
        if self.index is None or self.index.ntotal == 0:
            return []
        self._load_faiss()
        # Embed and normalize the query the same way as stored vectors.
        query_embedding = self._embed_with_ollama(query)
        query_np = np.array([query_embedding]).astype('float32')
        norm = np.linalg.norm(query_np)
        query_np = query_np / (norm + 1e-10)
        k = min(top_k, self.index.ntotal)
        scores, indices = self.index.search(query_np, k)
        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads missing results with -1; also guard stale indices.
            if idx < len(self.documents) and idx >= 0:
                doc = self.documents[idx]
                results.append({
                    'content': doc['content'],
                    'score': float(scores[0][i]),
                    'metadata': doc.get('metadata', {})
                })
        return results

    def delete_by_filename(self, filename):
        """Remove every chunk whose metadata filename matches *filename*.

        IndexFlatIP has no per-row delete, so the index is rebuilt by
        re-embedding the surviving documents.

        :param filename: value compared against metadata['filename']
        :return: number of chunks deleted
        """
        if not self.documents:
            return 0
        self._load_faiss()
        # Partition into survivors and deletions.
        remaining_docs = []
        deleted_count = 0
        for doc in self.documents:
            if doc.get('metadata', {}).get('filename') != filename:
                remaining_docs.append(doc)
            else:
                deleted_count += 1
        if deleted_count > 0:
            # Rebuild from scratch with the surviving chunks.
            self.documents = []
            self.index = self.faiss.IndexFlatIP(self.dimension)
            if remaining_docs:
                chunks = [doc['content'] for doc in remaining_docs]
                metadatas = [doc.get('metadata', {}) for doc in remaining_docs]
                embeddings = self._embed_batch(chunks)
                embeddings_np = np.array(embeddings).astype('float32')
                norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
                embeddings_np = embeddings_np / (norms + 1e-10)
                self.index.add(embeddings_np)
                for i, (chunk, meta) in enumerate(zip(chunks, metadatas)):
                    self.documents.append({
                        'id': i,
                        'content': chunk,
                        'metadata': meta
                    })
            self.save_index()
        return deleted_count

    def clear(self):
        """Drop all documents and reset the index, persisting the empty state."""
        self._load_faiss()
        self.index = self.faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.save_index()
        print("索引已清空")

    def get_stats(self):
        """Return counts and the distinct source filenames in the store.

        :return: {'total_chunks', 'total_files', 'files'}
        """
        files = set()
        for doc in self.documents:
            filename = doc.get('metadata', {}).get('filename')
            if filename:
                files.add(filename)
        return {
            'total_chunks': len(self.documents),
            'total_files': len(files),
            'files': list(files)
        }
# Global module-level instance (shared singleton).
vector_store = VectorStore()