# -*- coding: utf-8 -*-
"""
Vector store — embeddings generated by Ollama, indexed and searched with FAISS.
"""
import os
import json

import numpy as np
import requests

from config import INDEX_DIR, OLLAMA_URL, OLLAMA_EMBED_MODEL, TOP_K

class VectorStore:
    """FAISS-backed vector store; embeddings are produced by Ollama.

    ``self.documents`` holds ``{'id', 'content', 'metadata'}`` dicts in the
    same order as the corresponding rows of the FAISS index.  Vectors are
    L2-normalized and stored in an inner-product index, so search scores are
    cosine similarities.  Index and document list are persisted under
    ``INDEX_DIR``.
    """

    def __init__(self):
        self.index = None    # FAISS index; created lazily by _init_index()
        self.documents = []  # document content + metadata, aligned with index rows
        self.dimension = 768  # embedding dimension of nomic-embed-text
        self.index_file = os.path.join(INDEX_DIR, "faiss.index")
        self.docs_file = os.path.join(INDEX_DIR, "documents.json")
        self.faiss = None    # lazily imported faiss module (see _load_faiss)

    def _load_faiss(self):
        """Import FAISS on first use so this module imports even without faiss."""
        if self.faiss is None:
            import faiss
            self.faiss = faiss

    @staticmethod
    def _normalize(vectors):
        """L2-normalize rows of ``vectors`` as float32.

        Normalized vectors make inner-product search equivalent to cosine
        similarity.  The epsilon guards against division by zero for
        all-zero vectors.
        """
        arr = np.array(vectors).astype('float32')
        norms = np.linalg.norm(arr, axis=1, keepdims=True)
        return arr / (norms + 1e-10)

    def _embed_with_ollama(self, text, retry_count=3):
        """Return the embedding for ``text`` from the Ollama embeddings API.

        Retries up to ``retry_count`` times with linear backoff (2s, 4s, ...)
        and re-raises the last error when every attempt fails.

        Raises:
            urllib.error.HTTPError / Exception: last failure after retries.
            RuntimeError: if ``retry_count <= 0`` (no attempt was made).
        """
        import time
        import urllib.request
        import urllib.error

        url = f"{OLLAMA_URL}/api/embeddings"

        # The API rejects empty / non-string prompts; substitute a placeholder.
        if not text or not isinstance(text, str):
            text = "empty"

        # Strip NUL characters, which break JSON transport.
        text = text.replace('\x00', '')

        # Truncate long inputs: nomic-embed-text's context is ~2048 tokens and
        # Chinese is ~1.5 chars/token, so 1000 chars is a conservative cap.
        max_length = 1000
        if len(text) > max_length:
            text = text[:max_length]

        payload = {
            "model": OLLAMA_EMBED_MODEL,
            "prompt": text
        }

        last_error = None
        for attempt in range(retry_count):
            try:
                # urllib instead of requests to sidestep potential encoding issues.
                data = json.dumps(payload, ensure_ascii=False).encode('utf-8')
                req = urllib.request.Request(
                    url,
                    data=data,
                    headers={'Content-Type': 'application/json; charset=utf-8'},
                    method='POST'
                )

                with urllib.request.urlopen(req, timeout=120) as response:
                    result = json.loads(response.read().decode('utf-8'))
                    return result.get("embedding", [])

            except urllib.error.HTTPError as e:
                last_error = e
                error_body = e.read().decode('utf-8') if e.fp else 'N/A'
                print(f"Ollama HTTP 错误 (尝试 {attempt+1}/{retry_count}): {e.code} {e.reason}")
                print(f"响应内容: {error_body[:500]}")
                print(f"请求文本长度: {len(text)}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
            except Exception as e:
                last_error = e
                print(f"Ollama 嵌入失败 (尝试 {attempt+1}/{retry_count}): {e}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)

        # Bug fix: the original `raise last_error` raised None (a TypeError)
        # whenever retry_count <= 0; fail with a descriptive error instead.
        if last_error is None:
            raise RuntimeError("Ollama embedding failed: no attempt was made")
        raise last_error

    def _embed_batch(self, texts):
        """Embed ``texts`` one by one, pausing briefly between requests.

        Bug fix: the original contained an unreachable duplicate of this loop
        after the ``return`` statement (dead code); it has been removed.
        """
        import time
        embeddings = []
        for i, text in enumerate(texts):
            # Debug output helps diagnose texts the embed API rejects.
            print(f" 生成向量 {i+1}/{len(texts)}...")
            print(f" 文本长度: {len(text)}, 前50字符: {repr(text[:50])}")
            embeddings.append(self._embed_with_ollama(text))
            # Throttle to avoid overwhelming the Ollama server.
            if i < len(texts) - 1:
                time.sleep(1.0)
        return embeddings

    def _init_index(self):
        """Create an empty inner-product index if none exists yet."""
        self._load_faiss()
        if self.index is None:
            self.index = self.faiss.IndexFlatIP(self.dimension)

    def load_index(self):
        """Load index and documents from disk.

        Returns True on success; on any failure (or when no index exists)
        a fresh empty index is initialized and False is returned.
        """
        self._load_faiss()

        if os.path.exists(self.index_file) and os.path.exists(self.docs_file):
            try:
                print("正在加载已有索引...")

                # FAISS on Windows cannot open paths with non-ASCII (e.g.
                # Chinese) characters; copy to an ASCII temp path and read.
                import tempfile
                import shutil

                try:
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                        tmp_path = tmp.name
                    shutil.copy2(self.index_file, tmp_path)
                    self.index = self.faiss.read_index(tmp_path)
                    os.unlink(tmp_path)
                except Exception as e:
                    print(f"临时文件方式失败,尝试直接读取: {e}")
                    self.index = self.faiss.read_index(self.index_file)

                with open(self.docs_file, 'r', encoding='utf-8') as f:
                    self.documents = json.load(f)
                print(f"索引加载完成,共 {len(self.documents)} 个文档块")
                return True
            except Exception as e:
                print(f"加载索引失败: {e}")
                self._init_index()
                self.documents = []
                return False
        else:
            print("未找到已有索引,创建新索引")
            self._init_index()
            self.documents = []
            return False

    def save_index(self):
        """Persist the index and document list to disk."""
        self._load_faiss()
        if self.index is not None:
            # Make sure the target directory exists.
            os.makedirs(os.path.dirname(self.index_file), exist_ok=True)

            # FAISS on Windows cannot write to non-ASCII paths; write to an
            # ASCII temp file first, then move it into place.
            import tempfile
            import shutil

            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                    tmp_path = tmp.name
                self.faiss.write_index(self.index, tmp_path)
                shutil.move(tmp_path, self.index_file)
            except Exception as e:
                # Fall back to writing the target path directly.
                print(f"临时文件方式失败,尝试直接写入: {e}")
                self.faiss.write_index(self.index, self.index_file)

            with open(self.docs_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)
            print(f"索引已保存,共 {len(self.documents)} 个文档块")

    def add_documents(self, chunks, metadata=None):
        """Embed ``chunks`` and append them to the index.

        ``metadata`` (a single dict) is attached to every chunk.  Returns the
        number of chunks added; the index is saved automatically.
        """
        if not chunks:
            return 0

        self._load_faiss()
        self._init_index()

        print(f"正在为 {len(chunks)} 个文本块生成向量...")
        embeddings = self._embed_batch(chunks)

        # Adapt to the model's actual output dimension on mismatch.
        # NOTE(review): this replaces the index with an EMPTY one, discarding
        # any vectors added under the old dimension — presumably only ever
        # triggered on a fresh index; confirm before relying on it otherwise.
        if embeddings and len(embeddings[0]) != self.dimension:
            self.dimension = len(embeddings[0])
            self.index = self.faiss.IndexFlatIP(self.dimension)
            print(f"更新向量维度为: {self.dimension}")

        # Normalize so inner-product search yields cosine similarity.
        embeddings_np = self._normalize(embeddings)

        start_idx = len(self.documents)
        self.index.add(embeddings_np)

        # Record content + metadata for each new row, aligned with the index.
        for i, chunk in enumerate(chunks):
            self.documents.append({
                'id': start_idx + i,
                'content': chunk,
                'metadata': metadata or {}
            })

        # Persist automatically so callers cannot forget to save.
        self.save_index()

        return len(chunks)

    def search(self, query, top_k=None):
        """Return up to ``top_k`` documents most similar to ``query``.

        Each result dict carries 'content', 'score'/'similarity' (cosine),
        'filename' and the full 'metadata'.  Returns [] when the index is
        empty or not yet created.
        """
        if top_k is None:
            top_k = TOP_K

        if self.index is None or self.index.ntotal == 0:
            return []

        self._load_faiss()

        # Embed and normalize the query the same way as stored documents.
        query_np = self._normalize([self._embed_with_ollama(query)])

        k = min(top_k, self.index.ntotal)
        scores, indices = self.index.search(query_np, k)

        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads with -1 when fewer than k neighbors exist.
            if 0 <= idx < len(self.documents):
                doc = self.documents[idx]
                metadata = doc.get('metadata', {})
                results.append({
                    'content': doc['content'],
                    'score': float(scores[0][i]),
                    'filename': metadata.get('filename', '未知来源'),
                    'similarity': float(scores[0][i]),
                    'metadata': metadata
                })

        return results

    def delete_by_filename(self, filename):
        """Remove every chunk belonging to ``filename``; return count removed.

        IndexFlatIP offers no row removal, so the index is rebuilt by
        re-embedding all remaining documents (costly for large stores).
        """
        if not self.documents:
            return 0

        self._load_faiss()

        # Partition into kept vs. deleted documents.
        remaining_docs = [
            doc for doc in self.documents
            if doc.get('metadata', {}).get('filename') != filename
        ]
        deleted_count = len(self.documents) - len(remaining_docs)

        if deleted_count > 0:
            # Rebuild the index from scratch with the surviving documents.
            self.documents = []
            self.index = self.faiss.IndexFlatIP(self.dimension)

            if remaining_docs:
                chunks = [doc['content'] for doc in remaining_docs]
                metadatas = [doc.get('metadata', {}) for doc in remaining_docs]

                embeddings_np = self._normalize(self._embed_batch(chunks))
                self.index.add(embeddings_np)

                for i, (chunk, meta) in enumerate(zip(chunks, metadatas)):
                    self.documents.append({
                        'id': i,
                        'content': chunk,
                        'metadata': meta
                    })

            self.save_index()

        return deleted_count

    def clear(self):
        """Drop all vectors and documents, persisting the empty state."""
        self._load_faiss()
        self.index = self.faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.save_index()
        print("索引已清空")

    def get_stats(self):
        """Summarize the store: chunk count, distinct file count, file names."""
        files = set()
        for doc in self.documents:
            filename = doc.get('metadata', {}).get('filename')
            if filename:
                files.add(filename)

        return {
            'total_chunks': len(self.documents),
            'total_files': len(files),
            'files': list(files)
        }
|
||
|
||
# Global singleton instance, created at import time (the constructor only
# sets attributes — no disk or network I/O happens here).
vector_store = VectorStore()
|