# -*- coding: utf-8 -*-
"""
向量存储 - 使用 Ollama 生成向量FAISS 进行索引和检索
"""
import os
import json
import numpy as np
import requests
from config import INDEX_DIR, OLLAMA_URL, OLLAMA_EMBED_MODEL, TOP_K
class VectorStore:
    """FAISS-backed vector store with Ollama-generated embeddings.

    Embeddings are produced by calling the local Ollama ``/api/embeddings``
    endpoint; vectors are L2-normalized and stored in a flat inner-product
    FAISS index, so scores behave as cosine similarities.  Documents
    (content + metadata) are kept in a parallel JSON file so index position
    ``i`` always corresponds to ``self.documents[i]``.
    """

    def __init__(self):
        self.index = None       # FAISS index; created lazily via _init_index()
        self.documents = []     # parallel list of {'id', 'content', 'metadata'}
        self.dimension = 768    # embedding size of nomic-embed-text
        self.index_file = os.path.join(INDEX_DIR, "faiss.index")
        self.docs_file = os.path.join(INDEX_DIR, "documents.json")
        self.faiss = None       # lazily imported faiss module (heavy native dep)

    def _load_faiss(self):
        """Import FAISS on first use so module import stays cheap."""
        if self.faiss is None:
            import faiss
            self.faiss = faiss

    def _embed_with_ollama(self, text, retry_count=3):
        """Return the embedding vector for *text* from Ollama, with retries.

        Args:
            text: input text; empty/non-string input is replaced by "empty".
            retry_count: number of attempts before re-raising the last error.

        Returns:
            list[float]: the embedding (empty list if Ollama omits the field).

        Raises:
            The last urllib/HTTP error if every attempt fails.
        """
        import time
        import urllib.request
        import urllib.error

        url = f"{OLLAMA_URL}/api/embeddings"

        # Guard against empty / non-string input.
        if not text or not isinstance(text, str):
            text = "empty"
        # Strip NUL characters, which break JSON transport.
        text = text.replace('\x00', '')
        # Truncate long text: nomic-embed-text has a ~2048-token context and
        # Chinese runs ~1.5 chars/token, so 1000 chars is a conservative cap.
        max_length = 1000
        if len(text) > max_length:
            text = text[:max_length]

        payload = {
            "model": OLLAMA_EMBED_MODEL,
            "prompt": text
        }

        last_error = None
        for attempt in range(retry_count):
            try:
                # urllib instead of requests to avoid potential encoding issues.
                data = json.dumps(payload, ensure_ascii=False).encode('utf-8')
                req = urllib.request.Request(
                    url,
                    data=data,
                    headers={'Content-Type': 'application/json; charset=utf-8'},
                    method='POST'
                )
                with urllib.request.urlopen(req, timeout=120) as response:
                    result = json.loads(response.read().decode('utf-8'))
                return result.get("embedding", [])
            except urllib.error.HTTPError as e:
                last_error = e
                error_body = e.read().decode('utf-8') if e.fp else 'N/A'
                print(f"Ollama HTTP 错误 (尝试 {attempt+1}/{retry_count}): {e.code} {e.reason}")
                print(f"响应内容: {error_body[:500]}")
                print(f"请求文本长度: {len(text)}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2  # linear backoff: 2s, 4s, ...
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
            except Exception as e:
                last_error = e
                print(f"Ollama 嵌入失败 (尝试 {attempt+1}/{retry_count}): {e}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
        raise last_error

    def _embed_batch(self, texts):
        """Embed each text sequentially, pausing 1s between requests.

        NOTE: the original contained a second, unreachable copy of this loop
        after the return statement; it has been removed.
        """
        import time
        embeddings = []
        for i, text in enumerate(texts):
            # Progress / debug output for long ingestion runs.
            print(f" 生成向量 {i+1}/{len(texts)}...")
            print(f" 文本长度: {len(text)}, 前50字符: {repr(text[:50])}")
            embeddings.append(self._embed_with_ollama(text))
            # Small delay so we don't hammer the local Ollama server.
            if i < len(texts) - 1:
                time.sleep(1.0)
        return embeddings

    def _init_index(self):
        """Create an empty inner-product index if none exists yet."""
        self._load_faiss()
        if self.index is None:
            self.index = self.faiss.IndexFlatIP(self.dimension)

    def load_index(self):
        """Load index + documents from disk; fall back to a fresh index.

        Returns:
            bool: True if an existing index was loaded, False otherwise.
        """
        self._load_faiss()
        if os.path.exists(self.index_file) and os.path.exists(self.docs_file):
            try:
                print("正在加载已有索引...")
                # FAISS on Windows cannot read non-ASCII paths, so copy the
                # index to an ASCII temp path first and read from there.
                import tempfile
                import shutil
                try:
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                        tmp_path = tmp.name
                    shutil.copy2(self.index_file, tmp_path)
                    self.index = self.faiss.read_index(tmp_path)
                    os.unlink(tmp_path)
                except Exception as e:
                    print(f"临时文件方式失败,尝试直接读取: {e}")
                    self.index = self.faiss.read_index(self.index_file)
                with open(self.docs_file, 'r', encoding='utf-8') as f:
                    self.documents = json.load(f)
                print(f"索引加载完成,共 {len(self.documents)} 个文档块")
                return True
            except Exception as e:
                print(f"加载索引失败: {e}")
                self._init_index()
                self.documents = []
                return False
        else:
            print("未找到已有索引,创建新索引")
            self._init_index()
            self.documents = []
            return False

    def save_index(self):
        """Persist the FAISS index and document list to disk."""
        self._load_faiss()
        if self.index is not None:
            os.makedirs(os.path.dirname(self.index_file), exist_ok=True)
            # Same Windows non-ASCII-path workaround as load_index(): write
            # to an ASCII temp file, then move it into place.
            import tempfile
            import shutil
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                    tmp_path = tmp.name
                self.faiss.write_index(self.index, tmp_path)
                shutil.move(tmp_path, self.index_file)
            except Exception as e:
                print(f"临时文件方式失败,尝试直接写入: {e}")
                self.faiss.write_index(self.index, self.index_file)
        with open(self.docs_file, 'w', encoding='utf-8') as f:
            json.dump(self.documents, f, ensure_ascii=False, indent=2)
        print(f"索引已保存,共 {len(self.documents)} 个文档块")

    def add_documents(self, chunks, metadata=None):
        """Embed *chunks*, add them to the index, and auto-save.

        Args:
            chunks: list of text chunks to index.
            metadata: dict attached to every chunk (shared reference).

        Returns:
            int: number of chunks added.

        Raises:
            ValueError: if the embedding dimension changes after vectors have
                already been stored (recreating the index would silently
                desynchronize it from ``self.documents``).
        """
        if not chunks:
            return 0
        self._load_faiss()
        self._init_index()
        print(f"正在为 {len(chunks)} 个文本块生成向量...")
        embeddings = self._embed_batch(chunks)
        # Adopt the actual embedding dimension on first use; refuse to switch
        # dimensions once vectors exist (fixes silent index/document skew).
        if embeddings and len(embeddings[0]) != self.dimension:
            if self.index.ntotal > 0:
                raise ValueError(
                    f"embedding dimension changed from {self.dimension} to "
                    f"{len(embeddings[0])} with {self.index.ntotal} vectors stored"
                )
            self.dimension = len(embeddings[0])
            self.index = self.faiss.IndexFlatIP(self.dimension)
            print(f"更新向量维度为: {self.dimension}")
        # L2-normalize so inner product == cosine similarity.
        embeddings_np = np.array(embeddings).astype('float32')
        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
        embeddings_np = embeddings_np / (norms + 1e-10)
        start_idx = len(self.documents)
        self.index.add(embeddings_np)
        for i, chunk in enumerate(chunks):
            self.documents.append({
                'id': start_idx + i,
                'content': chunk,
                'metadata': metadata or {}
            })
        self.save_index()
        return len(chunks)

    def search(self, query, top_k=None):
        """Return the *top_k* most similar document chunks for *query*.

        Args:
            query: query text.
            top_k: result count; defaults to config TOP_K.

        Returns:
            list[dict]: each with 'content', 'score', 'filename',
            'similarity' (duplicate of score, kept for callers), 'metadata'.
        """
        if top_k is None:
            top_k = TOP_K
        if self.index is None or self.index.ntotal == 0:
            return []
        self._load_faiss()
        # Embed and normalize the query the same way as stored vectors.
        query_embedding = self._embed_with_ollama(query)
        query_np = np.array([query_embedding]).astype('float32')
        norm = np.linalg.norm(query_np)
        query_np = query_np / (norm + 1e-10)
        k = min(top_k, self.index.ntotal)
        scores, indices = self.index.search(query_np, k)
        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads missing results with -1; also guard list bounds.
            if idx < len(self.documents) and idx >= 0:
                doc = self.documents[idx]
                metadata = doc.get('metadata', {})
                results.append({
                    'content': doc['content'],
                    'score': float(scores[0][i]),
                    'filename': metadata.get('filename', '未知来源'),
                    'similarity': float(scores[0][i]),
                    'metadata': metadata
                })
        return results

    def delete_by_filename(self, filename):
        """Remove every chunk whose metadata filename matches, rebuilding
        the index (re-embeds all remaining chunks — flat indexes cannot
        delete in place).

        Returns:
            int: number of chunks deleted.
        """
        if not self.documents:
            return 0
        self._load_faiss()
        remaining_docs = []
        deleted_count = 0
        for doc in self.documents:
            if doc.get('metadata', {}).get('filename') != filename:
                remaining_docs.append(doc)
            else:
                deleted_count += 1
        if deleted_count > 0:
            # Rebuild the index from scratch with the surviving chunks.
            self.documents = []
            self.index = self.faiss.IndexFlatIP(self.dimension)
            if remaining_docs:
                chunks = [doc['content'] for doc in remaining_docs]
                metadatas = [doc.get('metadata', {}) for doc in remaining_docs]
                embeddings = self._embed_batch(chunks)
                embeddings_np = np.array(embeddings).astype('float32')
                norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
                embeddings_np = embeddings_np / (norms + 1e-10)
                self.index.add(embeddings_np)
                for i, (chunk, meta) in enumerate(zip(chunks, metadatas)):
                    self.documents.append({
                        'id': i,
                        'content': chunk,
                        'metadata': meta
                    })
            self.save_index()
        return deleted_count

    def clear(self):
        """Drop all vectors and documents, then persist the empty state."""
        self._load_faiss()
        self.index = self.faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.save_index()
        print("索引已清空")

    def get_stats(self):
        """Return counts of indexed chunks and distinct source files."""
        files = set()
        for doc in self.documents:
            filename = doc.get('metadata', {}).get('filename')
            if filename:
                files.add(filename)
        return {
            'total_chunks': len(self.documents),
            'total_files': len(files),
            'files': list(files)
        }
# Module-level singleton shared by the rest of the application.
vector_store = VectorStore()