xinli/rag-python/text_splitter.py
2026-02-24 16:49:05 +08:00

88 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
文本分块器 - 将长文本分割成小块
"""
import re
from config import CHUNK_SIZE, CHUNK_OVERLAP
# Sentence terminators: CJK full-width (。！？) and ASCII (.!?). The capturing
# group makes re.split keep each terminator so it can be re-attached.
_SENTENCE_END_RE = re.compile(r'([。！？.!?])')


def _carry_overlap(prev, sep, new, chunk_size, chunk_overlap):
    """Start a new chunk with ``new``, prefixed by the tail of ``prev`` as overlap.

    The carried tail is shortened so the result never exceeds ``chunk_size``.
    Nothing is carried when overlap is disabled, when no room remains, or when
    ``prev`` is no longer than ``chunk_overlap`` (it would be repeated whole).
    """
    keep = min(chunk_overlap, chunk_size - len(sep) - len(new))
    # keep must be > 0: prev[-0:] would be the whole string, not an empty tail.
    if keep <= 0 or len(prev) <= chunk_overlap:
        return new
    return prev[-keep:] + sep + new


def _split_sentences(para):
    """Split ``para`` into non-empty sentences, each keeping its terminator."""
    parts = _SENTENCE_END_RE.split(para)
    # With one capturing group re.split yields [text, term, ..., term, text]
    # (always odd length), so pair every text with the terminator after it.
    texts = parts[0::2]
    terminators = parts[1::2] + [""]
    return [t + d for t, d in zip(texts, terminators) if t + d]


def _hard_split(sentence, chunk_size, chunk_overlap):
    """Cut an over-long ``sentence`` into windows of at most ``chunk_size`` chars.

    Consecutive windows share ``chunk_overlap`` characters. The caller ensures
    ``chunk_overlap < chunk_size`` so every step advances.
    """
    step = chunk_size - chunk_overlap
    windows = []
    start = 0
    while True:
        windows.append(sentence[start:start + chunk_size])
        if start + chunk_size >= len(sentence):
            return windows
        start += step


def _split_long_paragraph(para, chunk_size, chunk_overlap):
    """Split a paragraph longer than ``chunk_size`` into pieces that each fit.

    Sentences are packed greedily with no separator (they keep their own
    spacing). A sentence that alone exceeds ``chunk_size`` is hard-split.
    Returns unstripped pieces; the list is non-empty for a non-empty ``para``.
    """
    pieces = []
    current = ""
    for sentence in _split_sentences(para):
        if len(current) + len(sentence) <= chunk_size:
            current += sentence
            continue
        if current:
            pieces.append(current)
        if len(sentence) > chunk_size:
            windows = _hard_split(sentence, chunk_size, chunk_overlap)
            pieces.extend(windows[:-1])
            current = windows[-1]
        else:
            current = _carry_overlap(current, "", sentence, chunk_size, chunk_overlap)
    if current:
        pieces.append(current)
    return pieces


def split_text(text, chunk_size=None, chunk_overlap=None):
    """
    Split text into chunks of at most ``chunk_size`` characters.

    Paragraphs (separated by blank lines) are packed greedily into chunks,
    joined with "\\n". A paragraph longer than ``chunk_size`` starts a fresh
    chunk and is split at sentence terminators (。！？.!?). A single sentence
    longer than ``chunk_size`` is cut into fixed-width windows. When a chunk is
    closed, up to ``chunk_overlap`` trailing characters are carried into the
    next one, shortened if needed so that chunk still fits.

    Args:
        text: the text to split.
        chunk_size: maximum characters per chunk; defaults to config CHUNK_SIZE.
        chunk_overlap: characters shared between consecutive chunks; defaults
            to config CHUNK_OVERLAP. Values >= chunk_size are clamped to
            chunk_size - 1.

    Returns:
        A list of non-empty, whitespace-stripped chunk strings; [] when
        ``text`` is empty or whitespace-only.

    Raises:
        ValueError: if chunk_size < 1 or chunk_overlap < 0.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_overlap is None:
        chunk_overlap = CHUNK_OVERLAP
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1, got {!r}".format(chunk_size))
    if chunk_overlap < 0:
        raise ValueError("chunk_overlap must be >= 0, got {!r}".format(chunk_overlap))
    # An overlap as large as a whole chunk would leave no room for new text.
    chunk_overlap = min(chunk_overlap, chunk_size - 1)

    if not text or not text.strip():
        return []

    # Normalize: trim, collapse 3+ newlines to one blank line, collapse space runs.
    text = text.strip()
    text = re.sub(r'\n{3,}', '\n\n', text)
    text = re.sub(r' {2,}', ' ', text)

    chunks = []

    def emit(chunk):
        chunk = chunk.strip()
        if chunk:
            chunks.append(chunk)

    current_chunk = ""
    for para in re.split(r'\n\n+', text):
        para = para.strip()
        if not para:
            continue
        if len(para) > chunk_size:
            # Flush what we have; the long paragraph starts on a fresh chunk.
            emit(current_chunk)
            pieces = _split_long_paragraph(para, chunk_size, chunk_overlap)
            for piece in pieces[:-1]:
                emit(piece)
            # Keep the last piece open so following short paragraphs can join it.
            current_chunk = pieces[-1]
        elif len(current_chunk) + len(para) + 1 <= chunk_size:
            current_chunk += ("\n" if current_chunk else "") + para
        else:
            emit(current_chunk)
            current_chunk = _carry_overlap(current_chunk, "\n", para, chunk_size, chunk_overlap)
    emit(current_chunk)
    return chunks