# -*- coding: utf-8 -*-
"""
Vector store: embeddings are generated with Ollama; FAISS handles indexing and retrieval.
"""
import os
import json

import numpy as np

from config import INDEX_DIR, OLLAMA_URL, OLLAMA_EMBED_MODEL, TOP_K


class VectorStore:
    def __init__(self):
        self.index = None
        self.documents = []  # chunk contents plus their metadata
        self.dimension = 768  # embedding dimension of nomic-embed-text
        self.index_file = os.path.join(INDEX_DIR, "faiss.index")
        self.docs_file = os.path.join(INDEX_DIR, "documents.json")
        self.faiss = None  # the faiss module, imported lazily

    def _load_faiss(self):
        """Lazily import FAISS."""
        if self.faiss is None:
            import faiss
            self.faiss = faiss

    def _embed_with_ollama(self, text, retry_count=3):
        """Generate an embedding with Ollama, retrying on failure."""
        import time
        import urllib.request
        import urllib.error

        url = f"{OLLAMA_URL}/api/embeddings"

        # Make sure the input is a non-empty string
        if not text or not isinstance(text, str):
            text = "empty"

        # Clean characters the server may reject
        text = text.replace('\x00', '')  # remove null bytes

        # Truncate overly long text (nomic-embed-text has a context limit of
        # roughly 2048 tokens; Chinese averages about 1.5 characters per
        # token, so 1000 characters is a conservative cap)
        max_length = 1000
        if len(text) > max_length:
            text = text[:max_length]

        payload = {
            "model": OLLAMA_EMBED_MODEL,
            "prompt": text
        }

        last_error = None
        for attempt in range(retry_count):
            try:
                # Use urllib instead of requests to avoid potential encoding issues
                data = json.dumps(payload, ensure_ascii=False).encode('utf-8')
                req = urllib.request.Request(
                    url,
                    data=data,
                    headers={'Content-Type': 'application/json; charset=utf-8'},
                    method='POST'
                )

                with urllib.request.urlopen(req, timeout=120) as response:
                    result = json.loads(response.read().decode('utf-8'))
                    return result.get("embedding", [])

            except urllib.error.HTTPError as e:
                last_error = e
                error_body = e.read().decode('utf-8') if e.fp else 'N/A'
                print(f"Ollama HTTP error (attempt {attempt+1}/{retry_count}): {e.code} {e.reason}")
                print(f"Response body: {error_body[:500]}")
                print(f"Request text length: {len(text)}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"Retrying in {wait_time} s...")
                    time.sleep(wait_time)
            except Exception as e:
                last_error = e
                print(f"Ollama embedding failed (attempt {attempt+1}/{retry_count}): {e}")
                if attempt < retry_count - 1:
                    wait_time = (attempt + 1) * 2
                    print(f"Retrying in {wait_time} s...")
                    time.sleep(wait_time)

        raise last_error

    def _embed_batch(self, texts):
        """Generate embeddings for a list of texts."""
        import time

        embeddings = []
        for i, text in enumerate(texts):
            # Log progress and a short preview of the text for debugging
            print(f"  Embedding {i+1}/{len(texts)}...")
            print(f"  Text length: {len(text)}, first 50 chars: {repr(text[:50])}")
            embedding = self._embed_with_ollama(text)
            embeddings.append(embedding)
            # Small delay so we don't hammer the server
            if i < len(texts) - 1:
                time.sleep(1.0)
        return embeddings

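    # Optional batch variant, left as a hedged sketch: newer Ollama releases
    # also expose POST /api/embed, which accepts {"model": ..., "input":
    # [list of strings]} and returns {"embeddings": [...]}, so a whole batch
    # can go out in one request instead of the per-text loop above. The
    # method name is ours; verify that your Ollama version supports the
    # endpoint before relying on this.
    def _embed_batch_single_request(self, texts):
        """Embed all texts in a single /api/embed request (if supported)."""
        import urllib.request

        payload = {"model": OLLAMA_EMBED_MODEL, "input": texts}
        data = json.dumps(payload, ensure_ascii=False).encode('utf-8')
        req = urllib.request.Request(
            f"{OLLAMA_URL}/api/embed",
            data=data,
            headers={'Content-Type': 'application/json; charset=utf-8'},
            method='POST'
        )
        with urllib.request.urlopen(req, timeout=300) as response:
            result = json.loads(response.read().decode('utf-8'))
        return result.get("embeddings", [])
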
    def _init_index(self):
        """Initialize the FAISS index."""
        self._load_faiss()
        if self.index is None:
            # Inner-product index; with L2-normalized vectors this gives cosine similarity
            self.index = self.faiss.IndexFlatIP(self.dimension)

    def load_index(self):
        """Load the index from disk."""
        self._load_faiss()

        if os.path.exists(self.index_file) and os.path.exists(self.docs_file):
            try:
                print("Loading existing index...")

                # FAISS cannot handle non-ASCII (e.g. Chinese) paths on
                # Windows, so copy to a temporary file first
                import tempfile
                import shutil

                try:
                    # Copy to a temp file, then read from there
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                        tmp_path = tmp.name
                    shutil.copy2(self.index_file, tmp_path)
                    self.index = self.faiss.read_index(tmp_path)
                    os.unlink(tmp_path)
                except Exception as e:
                    print(f"Temp-file read failed, trying to read directly: {e}")
                    self.index = self.faiss.read_index(self.index_file)

                with open(self.docs_file, 'r', encoding='utf-8') as f:
                    self.documents = json.load(f)
                print(f"Index loaded: {len(self.documents)} chunks")
                return True
            except Exception as e:
                print(f"Failed to load index: {e}")
                self._init_index()
                self.documents = []
                return False
        else:
            print("No existing index found, creating a new one")
            self._init_index()
            self.documents = []
            return False

    def save_index(self):
        """Persist the index to disk."""
        self._load_faiss()
        if self.index is not None:
            # Make sure the target directory exists
            os.makedirs(os.path.dirname(self.index_file), exist_ok=True)

            # FAISS cannot handle non-ASCII paths on Windows, so write to a
            # temporary file and move it into place
            import tempfile
            import shutil

            try:
                # Write to a temp file first
                with tempfile.NamedTemporaryFile(delete=False, suffix='.index') as tmp:
                    tmp_path = tmp.name

                self.faiss.write_index(self.index, tmp_path)

                # Move into the target location
                shutil.move(tmp_path, self.index_file)
            except Exception as e:
                # Fall back to writing directly
                print(f"Temp-file write failed, trying to write directly: {e}")
                self.faiss.write_index(self.index, self.index_file)

            with open(self.docs_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)
            print(f"Index saved: {len(self.documents)} chunks")

    def add_documents(self, chunks, metadata=None):
        """Add text chunks to the index."""
        if not chunks:
            return 0

        self._load_faiss()
        self._init_index()

        # Generate embeddings with Ollama
        print(f"Embedding {len(chunks)} chunks...")
        embeddings = self._embed_batch(chunks)

        # If the model's actual dimension differs, rebuild the index to match
        if embeddings and len(embeddings[0]) != self.dimension:
            self.dimension = len(embeddings[0])
            self.index = self.faiss.IndexFlatIP(self.dimension)
            print(f"Embedding dimension updated to: {self.dimension}")

        # L2-normalize so inner product equals cosine similarity
        embeddings_np = np.array(embeddings).astype('float32')
        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
        embeddings_np = embeddings_np / (norms + 1e-10)

        # Add to the index
        start_idx = len(self.documents)
        self.index.add(embeddings_np)

        # Store chunk contents and metadata
        for i, chunk in enumerate(chunks):
            doc = {
                'id': start_idx + i,
                'content': chunk,
                'metadata': metadata or {}
            }
            self.documents.append(doc)

        # Persist automatically
        self.save_index()

        return len(chunks)

    def search(self, query, top_k=None):
        """Search for relevant chunks."""
        if top_k is None:
            top_k = TOP_K

        if self.index is None or self.index.ntotal == 0:
            return []

        self._load_faiss()

        # Embed the query
        query_embedding = self._embed_with_ollama(query)
        query_np = np.array([query_embedding]).astype('float32')

        # Normalize
        norm = np.linalg.norm(query_np)
        query_np = query_np / (norm + 1e-10)

        # Search
        k = min(top_k, self.index.ntotal)
        scores, indices = self.index.search(query_np, k)

        # Build the result list (FAISS returns -1 for unfilled slots)
        results = []
        for i, idx in enumerate(indices[0]):
            if 0 <= idx < len(self.documents):
                doc = self.documents[idx]
                metadata = doc.get('metadata', {})
                results.append({
                    'content': doc['content'],
                    'score': float(scores[0][i]),
                    'filename': metadata.get('filename', 'unknown source'),
                    'similarity': float(scores[0][i]),
                    'metadata': metadata
                })

        return results

    def delete_by_filename(self, filename):
        """Delete all chunks belonging to the given file."""
        if not self.documents:
            return 0

        self._load_faiss()

        # Work out which documents to keep
        remaining_docs = []
        deleted_count = 0

        for doc in self.documents:
            if doc.get('metadata', {}).get('filename') != filename:
                remaining_docs.append(doc)
            else:
                deleted_count += 1

        if deleted_count > 0:
            # Rebuild the index from scratch
            self.documents = []
            self.index = self.faiss.IndexFlatIP(self.dimension)

            if remaining_docs:
                chunks = [doc['content'] for doc in remaining_docs]
                metadatas = [doc.get('metadata', {}) for doc in remaining_docs]

                embeddings = self._embed_batch(chunks)
                embeddings_np = np.array(embeddings).astype('float32')
                norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
                embeddings_np = embeddings_np / (norms + 1e-10)
                self.index.add(embeddings_np)

                for i, (chunk, meta) in enumerate(zip(chunks, metadatas)):
                    self.documents.append({
                        'id': i,
                        'content': chunk,
                        'metadata': meta
                    })

            self.save_index()

        return deleted_count

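    # Cheaper rebuild, left as a hedged sketch: delete_by_filename above
    # re-embeds every surviving chunk through Ollama, which is slow for large
    # indexes. A flat FAISS index can hand back its stored vectors via
    # reconstruct_n, so surviving rows could be copied into a fresh index
    # without calling the embedding model again. The helper name and the
    # keep_positions parameter (row positions to retain) are ours, and this
    # assumes the index is an IndexFlat variant.
    def _rebuild_keeping(self, keep_positions):
        """Rebuild the flat index keeping only the given row positions."""
        all_vecs = self.index.reconstruct_n(0, self.index.ntotal)  # (n, dim) float32
        new_index = self.faiss.IndexFlatIP(self.dimension)
        new_index.add(all_vecs[keep_positions])  # fancy-index the kept rows
        self.index = new_index
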
    def clear(self):
        """Clear the entire index."""
        self._load_faiss()
        self.index = self.faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.save_index()
        print("Index cleared")

    def get_stats(self):
        """Return index statistics."""
        files = set()
        for doc in self.documents:
            filename = doc.get('metadata', {}).get('filename')
            if filename:
                files.add(filename)

        return {
            'total_chunks': len(self.documents),
            'total_files': len(files),
            'files': list(files)
        }


# Global instance
vector_store = VectorStore()
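

# A minimal usage sketch, assuming a local Ollama server is running with
# OLLAMA_EMBED_MODEL already pulled and that config.py supplies the imported
# constants; the sample chunks and filename below are hypothetical.
if __name__ == "__main__":
    store = VectorStore()
    store.load_index()
    store.add_documents(
        ["FAISS is a library for similarity search.",
         "Ollama serves local embedding models."],
        metadata={"filename": "demo.txt"},
    )
    for hit in store.search("What is FAISS?", top_k=2):
        print(f"{hit['score']:.3f}  {hit['filename']}  {hit['content'][:60]}")
    print(store.get_stats())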