Ai_GirlFriend/lover/routers/voice_call.py
2026-03-05 17:18:04 +08:00

1668 lines
76 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import logging
import re
import time
from typing import List, Optional
import requests
import dashscope
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status, UploadFile, File
from fastapi.websockets import WebSocketState
from fastapi.responses import JSONResponse
from ..config import settings
from ..deps import AuthedUser, get_current_user, _fetch_user_from_php
from ..llm import chat_completion_stream
from ..tts import synthesize
from ..db import SessionLocal
from ..models import Lover, VoiceLibrary, User
try:
from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
except Exception: # dashscope 未安装时提供兜底
Recognition = None
RecognitionCallback = object # type: ignore
RecognitionResult = object # type: ignore
try:
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer, ResultCallback
except Exception:
AudioFormat = None # type: ignore
SpeechSynthesizer = None # type: ignore
ResultCallback = object # type: ignore
router = APIRouter(prefix="/voice", tags=["voice"])
logger = logging.getLogger("voice_call")
logger.setLevel(logging.INFO)
END_OF_TTS = "<<VOICE_CALL_TTS_END>>"
@router.get("/call/duration")
async def get_call_duration(user: AuthedUser = Depends(get_current_user)):
"""获取用户的语音通话时长配置"""
from ..db import SessionLocal
from ..models import User
from datetime import datetime
db = SessionLocal()
try:
user_row = db.query(User).filter(User.id == user.id).first()
if not user_row:
raise HTTPException(status_code=404, detail="用户不存在")
# 检查 VIP 状态vip_endtime 是 Unix 时间戳)
current_timestamp = int(datetime.utcnow().timestamp())
is_vip = user_row.vip_endtime and user_row.vip_endtime > current_timestamp
if is_vip:
duration = 0 # 0 表示无限制
else:
duration = 300000 # 普通用户 5 分钟
from ..response import success_response
return success_response({
"duration": duration,
"is_vip": is_vip
})
finally:
db.close()
class WSRecognitionCallback(RecognitionCallback): # type: ignore[misc]
"""ASR 回调,将句子级结果推入会话队列。"""
def __init__(self, session: "VoiceCallSession"):
super().__init__()
self.session = session
self._last_text: Optional[str] = None
def on_open(self) -> None:
logger.info("ASR connection opened")
def on_complete(self) -> None:
logger.info("ASR complete")
if self._last_text:
# 将最后的部分作为一句结束,防止没有 end 标记时丢失
self.session._schedule(self.session.handle_sentence(self._last_text))
logger.info("ASR flush last text on complete: %s", self._last_text)
self._last_text = None
def on_error(self, result: RecognitionResult) -> None:
logger.error("ASR error: %s", getattr(result, "message", None) or result)
if self._last_text:
self.session._schedule(self.session.handle_sentence(self._last_text))
logger.info("ASR flush last text on error: %s", self._last_text)
self._last_text = None
def on_close(self) -> None:
logger.info("ASR closed")
if self._last_text:
self.session._schedule(self.session.handle_sentence(self._last_text))
logger.info("ASR flush last text on close: %s", self._last_text)
self._last_text = None
def on_event(self, result: RecognitionResult) -> None:
sentence = result.get_sentence()
if not sentence:
return
sentences = sentence if isinstance(sentence, list) else [sentence]
for sent in sentences:
text = sent.get("text") if isinstance(sent, dict) else None
if not text:
continue
is_end = False
if isinstance(sent, dict):
is_end = (
bool(sent.get("is_sentence_end"))
or bool(sent.get("sentence_end"))
or RecognitionResult.is_sentence_end(sent)
)
if is_end:
self.session._schedule(self.session.handle_sentence(text))
self._last_text = None
else:
self.session._schedule(self.session.send_signal({"type": "partial_asr", "text": text}))
self._last_text = text
logger.info("ASR event end=%s sentence=%s", is_end, sent)
async def authenticate_websocket(websocket: WebSocket) -> AuthedUser:
"""复用 HTTP 鉴权逻辑Authorization / X-Token / x_user_id调试"""
headers = websocket.headers
token = None
auth_header = headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
if not token:
token = headers.get("x-token")
# 支持 query 携带
if not token:
token = websocket.query_params.get("token")
x_user_id = websocket.query_params.get("x_user_id")
if token:
payload = _fetch_user_from_php(token)
user_id = payload.get("id") or payload.get("user_id")
reg_step = payload.get("reg_step") or payload.get("stage") or 1
gender = payload.get("gender") or 0
nickname = payload.get("nickname") or payload.get("username") or ""
if not user_id:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户中心缺少用户ID")
return AuthedUser(
id=user_id,
reg_step=reg_step,
gender=gender,
nickname=nickname,
token=token,
)
if x_user_id is not None:
try:
uid = int(x_user_id)
except Exception:
uid = None
if uid is not None:
return AuthedUser(id=uid, reg_step=2, gender=0, nickname="debug-user", token="")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或未授权")
class VoiceCallSession:
def __init__(self, websocket: WebSocket, user: AuthedUser, require_ptt: bool = False):
self.websocket = websocket
self.user = user
self.require_ptt = require_ptt
self.mic_enabled = not require_ptt
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.asr_to_llm: asyncio.Queue[str] = asyncio.Queue()
self.llm_to_tts: asyncio.Queue[str] = asyncio.Queue()
self.is_speaking = False
self.lover: Optional[Lover] = None
self.db = SessionLocal()
self.voice_code: Optional[str] = None
self.history: List[dict] = [
{
"role": "system",
"content": self._compose_system_prompt(),
}
]
self.llm_task: Optional[asyncio.Task] = None
self.tts_task: Optional[asyncio.Task] = None
self.tts_stream_task: Optional[asyncio.Task] = None
self.silence_task: Optional[asyncio.Task] = None
self.cancel_event = asyncio.Event()
self.recognition: Optional[Recognition] = None
self.idle_task: Optional[asyncio.Task] = None
self.last_activity = time.time()
self.last_voice_activity = time.time()
self.has_voice_input = False
self.last_interrupt_time = 0.0
self.tts_first_chunk = True
async def start(self):
await self.websocket.accept()
self.loop = asyncio.get_running_loop()
# 预加载恋人与音色,避免在流式环节阻塞事件循环
self._prepare_profile()
# 不启动实时ASR避免MP3格式冲突
# 使用批量ASR处理音频
logger.info("🔄 跳过实时ASR启动将使用批量ASR处理MP3音频")
# 启动 LLM/TTS 后台任务
self.llm_task = asyncio.create_task(self._process_llm_loop())
self.tts_task = asyncio.create_task(self._process_tts_loop())
self.idle_task = asyncio.create_task(self._idle_watchdog())
self.silence_task = asyncio.create_task(self._silence_watchdog())
await self.send_signal({"type": "ready"})
if self.require_ptt:
await self.send_signal({"type": "info", "msg": "ptt_enabled"})
def _start_asr(self):
# 注意由于前端发送的是MP3格式音频实时ASR可能无法正常工作
# 主要依赖finalize_asr中的批量ASR处理
logger.info("启动ASR会话主要用于WebSocket连接实际识别使用批量API")
if Recognition is None:
logger.warning("未安装 dashscope跳过实时ASR启动")
return
if not settings.DASHSCOPE_API_KEY:
logger.warning("未配置 DASHSCOPE_API_KEY跳过实时ASR启动")
return
try:
dashscope.api_key = settings.DASHSCOPE_API_KEY
callback = WSRecognitionCallback(self)
# 启动实时ASR可能因为格式问题无法正常工作但保持连接
self.recognition = Recognition(
model=settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
format="pcm", # 保持PCM格式配置
sample_rate=settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
api_key=settings.DASHSCOPE_API_KEY,
callback=callback,
max_sentence_silence=10000, # 句子间最大静音时间 10秒
)
logger.info(
"实时ASR已启动 model=%s sample_rate=%s (注意主要使用批量ASR处理MP3音频)",
settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
)
self.recognition.start()
except Exception as e:
logger.warning(f"实时ASR启动失败将完全依赖批量ASR: {e}")
self.recognition = None
async def handle_sentence(self, text: str):
# 回合制AI 说话时忽略用户语音,提示稍后再说
if self.is_speaking:
await self.send_signal({"type": "info", "msg": "请等待 AI 说完再讲话"})
return
logger.info("Handle sentence: %s", text)
await self.asr_to_llm.put(text)
async def _process_llm_loop(self):
while True:
text = await self.asr_to_llm.get()
self.cancel_event.clear()
try:
await self._stream_llm(text)
except asyncio.CancelledError:
break
except Exception as exc:
logger.exception("LLM error", exc_info=exc)
await self.send_signal({"type": "error", "msg": "LLM 生成失败"})
self.is_speaking = False
async def _stream_llm(self, text: str):
self.history.append({"role": "user", "content": text})
# 控制历史长度
if len(self.history) > settings.VOICE_CALL_MAX_HISTORY:
self.history = self.history[-settings.VOICE_CALL_MAX_HISTORY :]
stream = chat_completion_stream(self.history)
self.is_speaking = True
self.tts_first_chunk = True
buffer = []
for chunk in stream:
if self.cancel_event.is_set():
break
buffer.append(chunk)
await self.llm_to_tts.put(chunk)
if not self.cancel_event.is_set():
await self.llm_to_tts.put(END_OF_TTS)
full_reply = "".join(buffer)
self.history.append({"role": "assistant", "content": full_reply})
if full_reply:
# 下行完整文本,便于前端展示/调试
await self.send_signal({"type": "reply_text", "text": full_reply})
else:
self.is_speaking = False
async def _process_tts_loop(self):
temp_buffer = []
punctuations = set(",。?!,.?!;")
while True:
token = await self.llm_to_tts.get()
if self.cancel_event.is_set():
temp_buffer = []
self.tts_first_chunk = True
continue
if token == END_OF_TTS:
# 将残余缓冲送出
if temp_buffer:
text_chunk = "".join(temp_buffer)
temp_buffer = []
clean_text = self._clean_tts_text(text_chunk)
if clean_text:
try:
async for chunk in self._synthesize_stream(clean_text):
if self.cancel_event.is_set():
break
await self.websocket.send_bytes(chunk)
self._touch()
except WebSocketDisconnect:
break
except Exception as exc:
logger.exception("TTS error", exc_info=exc)
await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
self.is_speaking = False
continue
self.tts_first_chunk = True
self.is_speaking = False
await self.send_signal({"type": "reply_end"})
continue
temp_buffer.append(token)
last_char = token[-1] if token else ""
threshold = 8 if self.tts_first_chunk else 18
if last_char in punctuations or len("".join(temp_buffer)) >= threshold:
text_chunk = "".join(temp_buffer)
temp_buffer = []
self.tts_first_chunk = False
clean_text = self._clean_tts_text(text_chunk)
if not clean_text:
continue
try:
async for chunk in self._synthesize_stream(clean_text):
if self.cancel_event.is_set():
break
await self.websocket.send_bytes(chunk)
self._touch()
except WebSocketDisconnect:
break
except Exception as exc:
logger.exception("TTS error", exc_info=exc)
await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
self.is_speaking = False
# 不可达,但保留以防逻辑调整
async def _synthesize_stream(self, text: str):
"""
调用 cosyvoice v2 流式合成,逐 chunk 返回。
如流式不可用则回落一次性合成。
"""
model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
voice = self._pick_voice_code() or settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
fmt = settings.VOICE_CALL_TTS_FORMAT.lower() if settings.VOICE_CALL_TTS_FORMAT else "mp3"
audio_format = AudioFormat.MP3_22050HZ_MONO_256KBPS if fmt == "mp3" else AudioFormat.PCM_16000HZ_MONO
# 直接同步合成,避免流式阻塞
audio_bytes, _fmt_name = synthesize(text, model=model, voice=voice, audio_format=audio_format) # type: ignore[arg-type]
yield audio_bytes
async def feed_audio(self, data: bytes):
logger.info(f"📥 feed_audio 被调用,数据大小: {len(data)} 字节")
if self.require_ptt and not self.mic_enabled:
# PTT 模式下未按住说话时丢弃音频
logger.warning("⚠️ PTT 模式下 mic 未启用,丢弃音频")
self._touch()
return
# 累积音频数据因为前端发送的是完整的MP3文件分块
if not hasattr(self, '_audio_buffer'):
self._audio_buffer = bytearray()
self._audio_buffer.extend(data)
logger.info(f"📦 累积音频数据,当前缓冲区大小: {len(self._audio_buffer)} 字节")
# 不启动实时ASR避免MP3格式冲突
# 所有音频处理都在finalize_asr中使用批量API完成
logger.info("🔄 跳过实时ASR启动使用批量ASR处理MP3音频")
logger.debug("recv audio chunk bytes=%s", len(data))
# 简单的活跃检测(基于数据大小)
if len(data) > 100: # 有实际音频数据
self.last_voice_activity = time.time()
self.has_voice_input = True
logger.info(f"🎤 检测到音频数据块")
self._touch()
def finalize_asr(self):
"""主动停止 ASR促使返回最终结果。"""
try:
# 处理累积的音频数据
if hasattr(self, '_audio_buffer') and len(self._audio_buffer) > 0:
logger.info(f"🎵 处理累积的音频数据,大小: {len(self._audio_buffer)} 字节")
# 直接使用批量ASR API处理MP3数据避免格式转换问题
try:
logger.info("🔄 使用批量ASR API处理MP3音频...")
import tempfile
import os
from dashscope.audio.asr import Transcription
from ..oss_utils import upload_audio_file, delete_audio_file
# 上传音频到OSS
file_url = upload_audio_file(bytes(self._audio_buffer), "mp3")
logger.info(f"📤 音频已上传到OSS: {file_url}")
# 调用批量ASR
task_response = Transcription.async_call(
model='paraformer-v2',
file_urls=[file_url],
parameters={
'format': 'mp3',
'sample_rate': 16000,
'enable_words': False
}
)
if task_response.status_code == 200:
task_id = task_response.output.task_id
logger.info(f"📋 批量ASR任务创建成功: {task_id}")
# 等待结果最多30秒
import time
max_wait = 30
start_time = time.time()
while time.time() - start_time < max_wait:
try:
result = Transcription.wait(task=task_id)
if result.status_code == 200:
if result.output.task_status == "SUCCEEDED":
logger.info("✅ 批量ASR识别成功")
# 解析结果并触发对话
text_result = ""
if result.output.results:
for item in result.output.results:
if isinstance(item, dict) and 'transcription_url' in item:
# 下载转录结果
import requests
resp = requests.get(item['transcription_url'], timeout=10)
if resp.status_code == 200:
transcription_data = resp.json()
if 'transcripts' in transcription_data:
for transcript in transcription_data['transcripts']:
if 'text' in transcript:
text_result += transcript['text'].strip() + " "
text_result = text_result.strip()
if text_result:
logger.info(f"🎯 批量ASR识别结果: {text_result}")
# 触发对话流程
self._schedule(self.handle_sentence(text_result))
else:
logger.warning("批量ASR未识别到文本内容")
self._schedule(self.handle_sentence("我听到了你的声音,但没有识别到具体内容"))
break
elif result.output.task_status == "FAILED":
error_code = getattr(result.output, 'code', 'Unknown')
logger.error(f"批量ASR任务失败: {error_code}")
if error_code == "SUCCESS_WITH_NO_VALID_FRAGMENT":
self._schedule(self.handle_sentence("我没有听到清晰的语音,请再说一遍"))
else:
self._schedule(self.handle_sentence("语音识别遇到了问题,请重试"))
break
else:
# 任务还在处理中,继续等待
time.sleep(2)
continue
else:
logger.error(f"批量ASR查询失败: {result.status_code}")
break
except Exception as wait_error:
logger.error(f"等待批量ASR结果失败: {wait_error}")
break
# 如果超时或失败,提供备用回复
if time.time() - start_time >= max_wait:
logger.warning("批量ASR处理超时")
self._schedule(self.handle_sentence("语音处理时间较长,我听到了你的声音"))
else:
logger.error(f"批量ASR任务创建失败: {task_response.status_code}")
self._schedule(self.handle_sentence("语音识别服务暂时不可用"))
# 清理OSS文件
try:
delete_audio_file(file_url)
logger.info("OSS临时文件已清理")
except:
pass
except Exception as batch_error:
logger.error(f"❌ 批量ASR处理失败: {batch_error}")
# 最后的备用方案:返回一个友好的消息
self._schedule(self.handle_sentence("我听到了你的声音,语音识别功能正在优化中"))
# 清空缓冲区
self._audio_buffer = bytearray()
# 停止实时ASR识别如果在运行
if self.recognition:
self.recognition.stop()
logger.info("实时ASR已停止")
except Exception as exc:
logger.warning("ASR finalize failed: %s", exc)
# 确保即使出错也能给用户反馈
try:
self._schedule(self.handle_sentence("我听到了你的声音"))
except:
pass
async def set_mic_enabled(self, enabled: bool, flush: bool = False):
if not self.require_ptt:
return
self.mic_enabled = enabled
await self.send_signal({"type": "info", "msg": "mic_on" if enabled else "mic_off"})
if not enabled and flush:
self.finalize_asr()
def _schedule(self, coro):
if self.loop:
self.loop.call_soon_threadsafe(asyncio.create_task, coro)
def _pick_voice_code(self) -> Optional[str]:
"""根据恋人配置或默认音色选择 voice_code。"""
if self.voice_code:
return self.voice_code
self._prepare_profile()
return self.voice_code
async def _interrupt(self):
self.cancel_event.set()
# 清空队列
while not self.llm_to_tts.empty():
try:
self.llm_to_tts.get_nowait()
except Exception:
break
await self.send_signal({"type": "interrupt", "code": "interrupted", "msg": "AI 打断,停止播放"})
self.is_speaking = False
self.last_interrupt_time = time.time()
async def close(self):
if self.db:
try:
self.db.close()
except Exception:
pass
if self.recognition:
try:
self.recognition.stop()
except Exception:
pass
if self.llm_task:
self.llm_task.cancel()
if self.tts_task:
self.tts_task.cancel()
if self.tts_stream_task:
self.tts_stream_task.cancel()
if self.idle_task:
self.idle_task.cancel()
if self.silence_task:
self.silence_task.cancel()
if self.websocket.client_state == WebSocketState.CONNECTED:
await self.websocket.close()
async def send_signal(self, payload: dict):
if self.websocket.client_state != WebSocketState.CONNECTED:
return
try:
await self.websocket.send_text(json.dumps(payload, ensure_ascii=False))
self._touch()
except WebSocketDisconnect:
return
def _load_lover(self) -> Optional[Lover]:
if self.lover is not None:
return self.lover
try:
self.lover = self.db.query(Lover).filter(Lover.user_id == self.user.id).first()
except Exception as exc:
logger.warning("Load lover failed: %s", exc)
self.lover = None
return self.lover
def _compose_system_prompt(self) -> str:
parts = [
f"你是用户 {self.user.nickname or '用户'} 的虚拟恋人,请用亲密、温暖、口语化的短句聊天,不要使用 Markdown 符号,不要输出表情、波浪线、星号或动作描述。",
"回复必须是对话内容,不要包含括号/星号/动作描写/舞台指令,不要用拟声词凑字数,保持简短自然的中文口语句子。",
"禁止涉政、违法、暴力、未成年相关内容。",
]
lover = self._load_lover()
if lover and lover.personality_prompt:
parts.append(f"人格设定:{lover.personality_prompt}")
return "\n".join(parts)
@staticmethod
def _clean_tts_text(text: str) -> str:
if not text:
return ""
# 去掉常见 Markdown/代码标记,保留文字内容
text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
text = re.sub(r"`([^`]*)`", r"\1", text)
text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
text = re.sub(r"\*[^\*]{0,80}\*", "", text) # 去掉 *动作* 片段
text = re.sub(r"[~]+", "", text) # 去掉波浪线
text = text.replace("*", "")
text = re.sub(r"\s+", " ", text)
return text.strip()
def _prepare_profile(self) -> None:
"""预加载恋人和音色,避免在流式阶段阻塞事件循环。"""
try:
lover = self._load_lover()
if lover and lover.voice_id:
voice = self.db.query(VoiceLibrary).filter(VoiceLibrary.id == lover.voice_id).first()
if voice and voice.voice_code:
self.voice_code = voice.voice_code
return
gender = None
if lover and lover.gender:
gender = lover.gender
if not gender:
gender = "female" if (self.user.gender or 0) == 1 else "male"
voice = (
self.db.query(VoiceLibrary)
.filter(VoiceLibrary.gender == gender, VoiceLibrary.is_default.is_(True))
.first()
)
if voice and voice.voice_code:
self.voice_code = voice.voice_code
return
voice = (
self.db.query(VoiceLibrary)
.filter(VoiceLibrary.gender == gender)
.order_by(VoiceLibrary.id.asc())
.first()
)
if voice and voice.voice_code:
self.voice_code = voice.voice_code
except Exception as exc:
logger.warning("Prepare profile failed: %s", exc)
def _touch(self):
self.last_activity = time.time()
async def _idle_watchdog(self):
timeout = settings.VOICE_CALL_IDLE_TIMEOUT or 0
if timeout <= 0:
return
try:
while True:
await asyncio.sleep(5)
if time.time() - self.last_activity > timeout:
await self.send_signal({"type": "error", "msg": "idle timeout"})
await self.close()
break
except asyncio.CancelledError:
return
async def _silence_watchdog(self):
"""长时间静默时关闭会话ASR 常驻不再因短静音 stop。"""
try:
while True:
await asyncio.sleep(1.0)
if time.time() - self.last_voice_activity > 60:
logger.info("Long silence, closing session")
await self.send_signal({"type": "error", "msg": "idle timeout"})
await self.close()
break
except asyncio.CancelledError:
return
@staticmethod
def _peak_pcm16(data: bytes) -> int:
"""快速估算 PCM 16bit 峰值幅度。"""
if not data:
return 0
view = memoryview(data)
# 每 2 字节一采样,取绝对值最大
max_val = 0
for i in range(0, len(view) - 1, 2):
sample = int.from_bytes(view[i : i + 2], "little", signed=True)
if sample < 0:
sample = -sample
if sample > max_val:
max_val = sample
return max_val
@router.post("/call/asr")
async def json_asr(
request: dict,
user: AuthedUser = Depends(get_current_user)
):
"""JSON ASR接收 base64 编码的音频数据并返回识别结果"""
try:
# 从请求中提取音频数据
if 'audio_data' not in request:
logger.error("请求中缺少 audio_data 字段")
raise HTTPException(status_code=400, detail="缺少 audio_data 字段")
audio_base64 = request['audio_data']
audio_format = request.get('format', 'mp3')
logger.info(f"收到 JSON ASR 请求,格式: {audio_format}")
# 解码 base64 音频数据
try:
import base64
audio_data = base64.b64decode(audio_base64)
logger.info(f"解码音频数据成功,大小: {len(audio_data)} 字节")
except Exception as decode_error:
logger.error(f"base64 解码失败: {decode_error}")
raise HTTPException(status_code=400, detail="音频数据解码失败")
# 检查音频数据是否为空
if not audio_data:
logger.error("解码后的音频数据为空")
raise HTTPException(status_code=400, detail="音频数据为空")
# 计算预期的音频时长
if audio_format.lower() == 'mp3':
# MP3 文件,粗略估算时长
expected_duration = len(audio_data) / 16000 # 粗略估算
logger.info(f"MP3 音频数据,预估时长: {expected_duration:.2f}")
else:
# PCM 格式16kHz 单声道 16bit每秒需要 32000 字节
expected_duration = len(audio_data) / 32000
logger.info(f"PCM 音频数据,预期时长: {expected_duration:.2f}")
if expected_duration < 0.1:
logger.warning("音频时长太短,可能无法识别")
test_text = f"音频时长太短({expected_duration:.2f}秒),请说话时间长一些"
from ..response import success_response
return success_response({"text": test_text})
# 检查 DashScope 配置
if not settings.DASHSCOPE_API_KEY:
logger.error("未配置 DASHSCOPE_API_KEY")
test_text = f"ASR 未配置,收到 {expected_duration:.1f}秒 音频"
from ..response import success_response
return success_response({"text": test_text})
# 设置 API Key
dashscope.api_key = settings.DASHSCOPE_API_KEY
# 使用 DashScope 进行批量 ASR
logger.info("开始调用 DashScope ASR...")
try:
from dashscope.audio.asr import Transcription
from ..oss_utils import upload_audio_file, delete_audio_file, test_oss_connection
# 首先测试 OSS 连接
logger.info("测试 OSS 连接...")
if not test_oss_connection():
# OSS 连接失败,使用临时方案
logger.warning("OSS 连接失败,使用临时测试方案")
test_text = f"OSS 暂不可用,但成功接收到 {expected_duration:.1f}{audio_format.upper()} 音频数据({len(audio_data)} 字节)"
from ..response import success_response
return success_response({"text": test_text})
logger.info("OSS 连接测试通过")
# 上传音频文件到 OSS
logger.info(f"上传 {audio_format.upper()} 音频到 OSS...")
file_url = upload_audio_file(audio_data, audio_format)
logger.info(f"音频文件上传成功: {file_url}")
# 调用 DashScope ASR
try:
logger.info("调用 DashScope Transcription API...")
logger.info(f"使用文件 URL: {file_url}")
task_response = Transcription.async_call(
model='paraformer-v2',
file_urls=[file_url],
parameters={
'format': audio_format,
'sample_rate': 16000,
'enable_words': False
}
)
logger.info(f"ASR 任务响应: status_code={task_response.status_code}")
if task_response.status_code != 200:
error_msg = getattr(task_response, 'message', 'Unknown error')
logger.error(f"ASR 任务创建失败: {error_msg}")
raise Exception(f"ASR 任务创建失败: {error_msg}")
task_id = task_response.output.task_id
logger.info(f"ASR 任务已创建: {task_id}")
# 等待识别完成
logger.info("等待 ASR 识别完成...")
import time
max_wait_time = 30
start_time = time.time()
transcribe_response = None
try:
import threading
import queue
result_queue = queue.Queue()
exception_queue = queue.Queue()
def wait_for_result():
try:
result = Transcription.wait(task=task_id)
result_queue.put(result)
except Exception as e:
exception_queue.put(e)
# 启动等待线程
wait_thread = threading.Thread(target=wait_for_result)
wait_thread.daemon = True
wait_thread.start()
# 轮询检查结果或超时
while time.time() - start_time < max_wait_time:
try:
transcribe_response = result_queue.get_nowait()
logger.info("ASR 任务完成")
break
except queue.Empty:
pass
try:
exception = exception_queue.get_nowait()
logger.error(f"ASR 等待过程中出错: {exception}")
raise exception
except queue.Empty:
pass
elapsed = time.time() - start_time
logger.info(f"ASR 任务仍在处理中... 已等待 {elapsed:.1f}")
time.sleep(2)
if transcribe_response is None:
logger.error(f"ASR 任务超时({max_wait_time}秒)")
from ..response import success_response
return success_response({"text": f"语音识别处理时间较长,请稍后重试(音频时长: {expected_duration:.1f}秒)"})
except Exception as wait_error:
logger.error(f"ASR 等待过程中出错: {wait_error}")
from ..response import success_response
return success_response({"text": f"语音识别服务暂时不可用,请稍后重试"})
logger.info(f"ASR 识别响应: status_code={transcribe_response.status_code}")
if transcribe_response.status_code != 200:
error_msg = getattr(transcribe_response, 'message', 'Unknown error')
logger.error(f"ASR 识别失败: {error_msg}")
raise Exception(f"ASR 识别失败: {error_msg}")
# 检查任务状态
result = transcribe_response.output
logger.info(f"ASR 任务状态: {result.task_status}")
if result.task_status == "SUCCEEDED":
logger.info("ASR 识别成功,开始解析结果...")
elif result.task_status == "FAILED":
error_code = getattr(result, 'code', 'Unknown')
error_message = getattr(result, 'message', 'Unknown error')
logger.error(f"ASR 任务失败: {error_code} - {error_message}")
if error_code == "SUCCESS_WITH_NO_VALID_FRAGMENT":
user_message = "音频中未检测到有效语音,请确保录音时有说话内容"
elif error_code == "DECODE_ERROR":
user_message = "音频格式解码失败,请检查录音设置"
logger.error("音频解码失败 - 可能的原因:")
logger.error("1. 音频格式不正确或损坏")
logger.error("2. 编码参数不匹配建议16kHz, 单声道, 64kbps")
logger.error("3. 文件头信息缺失或错误")
elif error_code == "FILE_DOWNLOAD_FAILED":
user_message = "无法下载音频文件,请检查网络连接"
elif error_code == "AUDIO_FORMAT_UNSUPPORTED":
user_message = "音频格式不支持,请使用标准格式录音"
else:
user_message = f"语音识别失败: {error_message}"
from ..response import success_response
return success_response({"text": user_message})
else:
logger.warning(f"ASR 任务状态未知: {result.task_status}")
from ..response import success_response
return success_response({"text": f"语音识别状态异常: {result.task_status}"})
# 解析识别结果
text_result = ""
if hasattr(result, 'results') and result.results:
logger.info(f"找到 results 字段,长度: {len(result.results)}")
for i, item in enumerate(result.results):
if isinstance(item, dict) and 'transcription_url' in item and item['transcription_url']:
transcription_url = item['transcription_url']
logger.info(f"找到 transcription_url: {transcription_url}")
try:
import requests
response = requests.get(transcription_url, timeout=10)
if response.status_code == 200:
transcription_data = response.json()
logger.info(f"转录数据: {transcription_data}")
if 'transcripts' in transcription_data:
for transcript in transcription_data['transcripts']:
if 'text' in transcript:
text_result += transcript['text'] + " "
logger.info(f"提取转录文本: {transcript['text']}")
if text_result.strip():
break
except Exception as e:
logger.error(f"处理 transcription_url 失败: {e}")
text_result = text_result.strip()
if not text_result:
logger.warning("ASR 未识别到文本内容")
text_result = f"未识别到语音内容({expected_duration:.1f}秒音频)"
logger.info(f"最终 ASR 识别结果: {text_result}")
from ..response import success_response
return success_response({"text": text_result})
finally:
# 清理 OSS 上的临时文件
try:
delete_audio_file(file_url)
logger.info("OSS 临时文件已清理")
except Exception as e:
logger.warning(f"清理 OSS 文件失败: {e}")
except Exception as asr_error:
logger.error(f"DashScope ASR 调用失败: {asr_error}", exc_info=True)
error_msg = str(asr_error)
if "OSS" in error_msg:
test_text = f"OSS 配置问题,收到 {expected_duration:.1f}秒 音频"
elif "Transcription" in error_msg:
test_text = f"ASR 服务异常,收到 {expected_duration:.1f}秒 音频"
else:
test_text = f"ASR 处理失败,收到 {expected_duration:.1f}秒 音频"
logger.info(f"返回备用文本: {test_text}")
from ..response import success_response
return success_response({"text": test_text})
except HTTPException:
raise
except Exception as e:
logger.error(f"JSON ASR 处理错误: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"ASR 处理失败: {str(e)}")
@router.post("/call/batch_asr")
async def batch_asr(
audio: UploadFile = File(...),
user: AuthedUser = Depends(get_current_user)
):
"""批量 ASR接收完整音频文件并返回识别结果"""
try:
# 读取音频数据
audio_data = await audio.read()
logger.info(f"收到音频文件,大小: {len(audio_data)} 字节,文件名: {audio.filename}")
# 检查音频数据是否为空
if not audio_data:
logger.error("音频数据为空")
raise HTTPException(status_code=400, detail="音频数据为空")
# 计算预期的音频时长
if audio.filename and audio.filename.lower().endswith('.mp3'):
# MP3 文件,无法直接计算时长,跳过时长检查
expected_duration = len(audio_data) / 16000 # 粗略估算
logger.info(f"MP3 音频文件,预估时长: {expected_duration:.2f}")
else:
# PCM 格式16kHz 单声道 16bit每秒需要 32000 字节
expected_duration = len(audio_data) / 32000
logger.info(f"PCM 音频文件,预期时长: {expected_duration:.2f}")
if expected_duration < 0.1:
logger.warning("音频时长太短,可能无法识别")
test_text = f"音频时长太短({expected_duration:.2f}秒),请说话时间长一些"
from ..response import success_response
return success_response({"text": test_text})
# 检查 DashScope 配置
if not settings.DASHSCOPE_API_KEY:
logger.error("未配置 DASHSCOPE_API_KEY")
test_text = f"ASR 未配置,收到 {expected_duration:.1f}秒 音频"
from ..response import success_response
return success_response({"text": test_text})
# 设置 API Key
dashscope.api_key = settings.DASHSCOPE_API_KEY
# 使用 DashScope 进行批量 ASR
logger.info("开始调用 DashScope ASR...")
try:
import wave
import tempfile
import os
from dashscope.audio.asr import Transcription
from ..oss_utils import upload_audio_file, delete_audio_file, test_oss_connection
# 首先测试 OSS 连接
logger.info("测试 OSS 连接...")
if not test_oss_connection():
# OSS 连接失败,使用临时方案
logger.warning("OSS 连接失败,使用临时测试方案")
test_text = f"OSS 暂不可用,但成功接收到 {expected_duration:.1f}秒 MP3 音频文件({len(audio_data)} 字节)"
from ..response import success_response
return success_response({"text": test_text})
logger.info("OSS 连接测试通过")
# 检测音频格式并处理
if audio.filename and audio.filename.lower().endswith('.mp3'):
# MP3 文件,直接上传
logger.info("检测到 MP3 格式,直接上传")
file_url = upload_audio_file(audio_data, "mp3")
logger.info(f"MP3 文件上传成功: {file_url}")
else:
# PCM 数据,转换为 WAV 格式
logger.info("检测到 PCM 格式,转换为 WAV")
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
# 创建 WAV 文件
with wave.open(temp_file.name, 'wb') as wav_file:
wav_file.setnchannels(1) # 单声道
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(16000) # 16kHz
wav_file.writeframes(audio_data)
temp_file_path = temp_file.name
try:
# 读取 WAV 文件数据
with open(temp_file_path, 'rb') as f:
wav_data = f.read()
# 上传 WAV 文件到 OSS 并获取公网 URL
logger.info("上传 WAV 文件到 OSS...")
file_url = upload_audio_file(wav_data, "wav")
logger.info(f"WAV 文件上传成功: {file_url}")
finally:
# 清理本地临时文件
try:
os.unlink(temp_file_path)
except Exception as e:
logger.warning(f"清理临时文件失败: {e}")
# 调用 DashScope ASR
try:
logger.info("调用 DashScope Transcription API...")
logger.info(f"使用文件 URL: {file_url}")
task_response = Transcription.async_call(
model='paraformer-v2',
file_urls=[file_url],
parameters={
'format': 'mp3',
'sample_rate': 16000,
'enable_words': False
}
)
logger.info(f"ASR 任务响应: status_code={task_response.status_code}")
logger.info(f"ASR 任务响应完整内容: {task_response}")
if hasattr(task_response, 'message'):
logger.info(f"ASR 任务消息: {task_response.message}")
if hasattr(task_response, 'output'):
logger.info(f"ASR 任务输出: {task_response.output}")
if task_response.status_code != 200:
error_msg = getattr(task_response, 'message', 'Unknown error')
logger.error(f"ASR 任务创建失败: {error_msg}")
# 检查具体错误类型
if hasattr(task_response, 'output') and task_response.output:
logger.error(f"错误详情: {task_response.output}")
raise Exception(f"ASR 任务创建失败: {error_msg}")
task_id = task_response.output.task_id
logger.info(f"ASR 任务已创建: {task_id}")
# 等待识别完成,使用更智能的轮询策略
logger.info("等待 ASR 识别完成...")
import time
# 设置最大等待时间45秒给前端留足够缓冲
max_wait_time = 45
start_time = time.time()
transcribe_response = None
try:
# 使用一个循环来检查超时但仍然使用原始的wait方法
logger.info(f"开始等待ASR任务完成最大等待时间: {max_wait_time}")
# 在单独的线程中执行wait操作这样可以控制超时
import threading
import queue
result_queue = queue.Queue()
exception_queue = queue.Queue()
def wait_for_result():
try:
result = Transcription.wait(task=task_id)
result_queue.put(result)
except Exception as e:
exception_queue.put(e)
# 启动等待线程
wait_thread = threading.Thread(target=wait_for_result)
wait_thread.daemon = True
wait_thread.start()
# 轮询检查结果或超时
while time.time() - start_time < max_wait_time:
# 检查是否有结果
try:
transcribe_response = result_queue.get_nowait()
logger.info("ASR 任务完成")
break
except queue.Empty:
pass
# 检查是否有异常
try:
exception = exception_queue.get_nowait()
logger.error(f"ASR 等待过程中出错: {exception}")
raise exception
except queue.Empty:
pass
# 显示进度
elapsed = time.time() - start_time
logger.info(f"ASR 任务仍在处理中... 已等待 {elapsed:.1f}")
time.sleep(3) # 每3秒检查一次
# 检查是否超时
if transcribe_response is None:
logger.error(f"ASR 任务超时({max_wait_time}任务ID: {task_id}")
# 返回一个友好的超时消息而不是抛出异常
from ..response import success_response
return success_response({"text": f"语音识别处理时间较长,请稍后重试(音频时长: {expected_duration:.1f}秒)"})
except Exception as wait_error:
logger.error(f"ASR 等待过程中出错: {wait_error}")
# 返回友好的错误消息而不是抛出异常
from ..response import success_response
return success_response({"text": f"语音识别服务暂时不可用,请稍后重试"})
logger.info(f"ASR 识别响应: status_code={transcribe_response.status_code}")
if hasattr(transcribe_response, 'message'):
logger.info(f"ASR 识别消息: {transcribe_response.message}")
if transcribe_response.status_code != 200:
error_msg = getattr(transcribe_response, 'message', 'Unknown error')
logger.error(f"ASR 识别失败: {error_msg}")
raise Exception(f"ASR 识别失败: {error_msg}")
# 检查任务状态
result = transcribe_response.output
logger.info(f"ASR 任务状态: {result.task_status}")
if result.task_status == "SUCCEEDED":
logger.info("ASR 识别成功,开始解析结果...")
elif result.task_status == "FAILED":
error_code = getattr(result, 'code', 'Unknown')
error_message = getattr(result, 'message', 'Unknown error')
logger.error(f"ASR 任务失败: {error_code} - {error_message}")
# 提供更友好的错误信息
if error_code == "FILE_DOWNLOAD_FAILED":
user_message = "无法下载音频文件,请检查网络连接"
elif error_code == "SUCCESS_WITH_NO_VALID_FRAGMENT":
user_message = "音频中未检测到有效语音,请确保录音时有说话内容"
elif error_code == "AUDIO_FORMAT_UNSUPPORTED":
user_message = "音频格式不支持,请使用标准格式录音"
else:
user_message = f"语音识别失败: {error_message}"
from ..response import success_response
return success_response({"text": user_message})
else:
logger.warning(f"ASR 任务状态未知: {result.task_status}")
from ..response import success_response
return success_response({"text": f"语音识别状态异常: {result.task_status}"})
# 解析识别结果
logger.info(f"ASR 识别结果类型: {type(result)}")
logger.info(f"ASR 识别完成,结果: {result}")
# 提取文本内容
text_result = ""
logger.info(f"开始解析 ASR 结果...")
logger.info(f"result 对象类型: {type(result)}")
# 打印完整的结果对象以便调试
try:
result_dict = vars(result) if hasattr(result, '__dict__') else result
logger.info(f"完整 result 对象: {result_dict}")
except Exception as e:
logger.info(f"无法序列化 result 对象: {e}")
logger.info(f"result 对象字符串: {str(result)}")
# 尝试多种方式提取文本
if hasattr(result, 'results') and result.results:
logger.info(f"找到 results 字段,长度: {len(result.results)}")
for i, item in enumerate(result.results):
logger.info(f"处理 result[{i}]: {type(item)}")
# 打印每个 item 的详细信息
try:
if hasattr(item, '__dict__'):
item_dict = vars(item)
logger.info(f"result[{i}] 对象内容: {item_dict}")
else:
logger.info(f"result[{i}] 内容: {item}")
except Exception as e:
logger.info(f"无法序列化 result[{i}]: {e}")
# 如果 item 是字典
if isinstance(item, dict):
logger.info(f"result[{i}] 是字典,键: {list(item.keys())}")
# 检查 transcription_urlDashScope 的实际返回格式)
if 'transcription_url' in item and item['transcription_url']:
transcription_url = item['transcription_url']
logger.info(f"找到 transcription_url: {transcription_url}")
try:
# 下载转录结果
import requests
response = requests.get(transcription_url, timeout=10)
if response.status_code == 200:
transcription_data = response.json()
logger.info(f"转录数据: {transcription_data}")
# 解析转录数据
if 'transcripts' in transcription_data:
for transcript in transcription_data['transcripts']:
if 'text' in transcript:
text_result += transcript['text'] + " "
logger.info(f"提取转录文本: {transcript['text']}")
elif 'text' in transcription_data:
text_result += transcription_data['text'] + " "
logger.info(f"提取直接文本: {transcription_data['text']}")
# 如果找到了文本,跳出循环
if text_result.strip():
break
else:
logger.error(f"下载转录结果失败: HTTP {response.status_code}")
except Exception as e:
logger.error(f"处理 transcription_url 失败: {e}")
# 检查各种可能的字段
elif 'transcription' in item and item['transcription']:
transcription = item['transcription']
logger.info(f"找到字段 transcription: {transcription}")
if isinstance(transcription, str):
text_result += transcription + " "
logger.info(f"提取字符串文本: {transcription}")
elif isinstance(transcription, dict):
# 检查嵌套的文本字段
for text_key in ['text', 'content', 'transcript']:
if text_key in transcription:
text_result += str(transcription[text_key]) + " "
logger.info(f"提取嵌套文本: {transcription[text_key]}")
break
# 检查直接的 text 字段
elif 'text' in item and item['text']:
text_result += item['text'] + " "
logger.info(f"提取 item 字典文本: {item['text']}")
# 如果 item 是对象
else:
# 检查各种可能的属性
for attr in ['transcription', 'text', 'transcript', 'content']:
if hasattr(item, attr):
value = getattr(item, attr)
if value:
logger.info(f"找到属性 {attr}: {value}")
if isinstance(value, str):
text_result += value + " "
logger.info(f"提取属性文本: {value}")
break
# 如果 results 中没有找到文本,检查顶级字段
if not text_result:
logger.info("未从 results 提取到文本,检查顶级字段")
for attr in ['text', 'transcription', 'transcript', 'content']:
if hasattr(result, attr):
value = getattr(result, attr)
if value:
logger.info(f"找到顶级属性 {attr}: {value}")
text_result = str(value)
break
# 如果还是没有找到,尝试从原始响应中提取
if not text_result:
logger.warning("所有标准方法都未能提取到文本")
logger.info("尝试从原始响应中查找文本...")
# 将整个结果转换为字符串并查找可能的文本
result_str = str(result)
logger.info(f"结果字符串: {result_str}")
# 简单的文本提取逻辑
if "text" in result_str.lower():
logger.info("在结果字符串中发现 'text' 关键字")
# 这里可以添加更复杂的文本提取逻辑
text_result = "检测到语音内容,但解析格式需要调整"
else:
text_result = "语音识别成功,但未能解析文本内容"
# 清理文本
text_result = text_result.strip()
if not text_result:
logger.warning("ASR 未识别到文本内容")
logger.info(f"完整的 result 对象: {vars(result) if hasattr(result, '__dict__') else result}")
text_result = f"未识别到语音内容({expected_duration:.1f}秒音频)"
logger.info(f"最终 ASR 识别结果: {text_result}")
from ..response import success_response
return success_response({"text": text_result})
finally:
# 清理 OSS 上的临时文件
try:
delete_audio_file(file_url)
logger.info("OSS 临时文件已清理")
except Exception as e:
logger.warning(f"清理 OSS 文件失败: {e}")
except Exception as asr_error:
logger.error(f"DashScope ASR 调用失败: {asr_error}", exc_info=True)
# 如果 ASR 失败,返回有意义的测试文本
error_msg = str(asr_error)
if "OSS" in error_msg:
test_text = f"OSS 配置问题,收到 {expected_duration:.1f}秒 音频"
elif "Transcription" in error_msg:
test_text = f"ASR 服务异常,收到 {expected_duration:.1f}秒 音频"
else:
test_text = f"ASR 处理失败,收到 {expected_duration:.1f}秒 音频"
logger.info(f"返回备用文本: {test_text}")
from ..response import success_response
return success_response({"text": test_text})
except HTTPException:
raise
except Exception as e:
logger.error(f"ASR 处理错误: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"ASR 处理失败: {str(e)}")
@router.websocket("/call")
async def voice_call(websocket: WebSocket):
try:
user = await authenticate_websocket(websocket)
except HTTPException as exc:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
ptt_param = (websocket.query_params.get("ptt") or "").strip().lower()
require_ptt = settings.VOICE_CALL_REQUIRE_PTT or ptt_param in ("1", "true", "yes", "on")
session = VoiceCallSession(websocket, user, require_ptt=require_ptt)
try:
await session.start()
except HTTPException as exc:
try:
await websocket.accept()
await websocket.send_text(json.dumps({"type": "error", "msg": exc.detail}))
await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
except Exception:
pass
return
try:
while True:
msg = await websocket.receive()
if "bytes" in msg and msg["bytes"] is not None:
audio_data = msg["bytes"]
logger.info(f"📨 收到二进制消息,大小: {len(audio_data)} 字节")
await session.feed_audio(audio_data)
elif "text" in msg and msg["text"]:
# 简单心跳/信令
text = msg["text"].strip()
logger.info(f"📨 收到文本消息: {text}")
lower_text = text.lower()
if lower_text in ("mic_on", "ptt_on"):
await session.set_mic_enabled(True)
elif lower_text in ("mic_off", "ptt_off"):
await session.set_mic_enabled(False, flush=True)
elif text == "ping":
await websocket.send_text("pong")
elif text in ("end", "stop", "flush"):
logger.info("📥 收到结束信号,调用 finalize_asr")
session.finalize_asr()
await session.send_signal({"type": "info", "msg": "ASR stopped manually"})
else:
await session.send_signal({"type": "info", "msg": "文本消息已忽略"})
if msg.get("type") == "websocket.disconnect":
break
except WebSocketDisconnect:
pass
finally:
await session.close()
@router.post("/call/conversation")
async def voice_conversation(
request: dict,
user: AuthedUser = Depends(get_current_user)
):
"""
完整的语音对话流程:
1. 接收音频数据base64
2. ASR 识别为文字
3. LLM 生成回复
4. TTS 合成语音
5. 返回语音数据base64
"""
try:
# 1. 接收并解码音频数据
if 'audio_data' not in request:
raise HTTPException(status_code=400, detail="缺少 audio_data 字段")
audio_base64 = request['audio_data']
audio_format = request.get('format', 'wav')
logger.info(f"收到语音对话请求,用户: {user.id}, 格式: {audio_format}")
# 解码音频
import base64
audio_data = base64.b64decode(audio_base64)
logger.info(f"音频数据大小: {len(audio_data)} 字节")
# 2. ASR 识别
logger.info("开始 ASR 识别...")
from dashscope.audio.asr import Transcription
from ..oss_utils import upload_audio_file, delete_audio_file
# 上传到 OSS
file_url = upload_audio_file(audio_data, audio_format)
logger.info(f"音频已上传: {file_url}")
try:
# 调用 ASR
task_response = Transcription.async_call(
model='paraformer-v2',
file_urls=[file_url],
parameters={
'format': audio_format,
'sample_rate': 16000,
'enable_words': False
}
)
if task_response.status_code != 200:
raise Exception(f"ASR 任务创建失败")
task_id = task_response.output.task_id
logger.info(f"ASR 任务创建: {task_id}")
# 等待识别结果
import time
max_wait = 30
start_time = time.time()
user_text = None
while time.time() - start_time < max_wait:
result = Transcription.wait(task=task_id)
if result.status_code == 200:
if result.output.task_status == "SUCCEEDED":
# 解析识别结果
if hasattr(result.output, 'results') and result.output.results:
for item in result.output.results:
if isinstance(item, dict) and 'transcription_url' in item:
import requests
resp = requests.get(item['transcription_url'], timeout=10)
if resp.status_code == 200:
data = resp.json()
if 'transcripts' in data:
for transcript in data['transcripts']:
if 'text' in transcript:
user_text = transcript['text'].strip()
break
if user_text:
break
break
elif result.output.task_status == "FAILED":
error_code = getattr(result.output, 'code', 'Unknown')
logger.error(f"ASR 失败: {error_code}")
break
time.sleep(2)
if not user_text:
logger.warning("ASR 未识别到文本")
from ..response import success_response
return success_response({
"user_text": "",
"ai_text": "抱歉,我没有听清楚,请再说一遍",
"audio_data": None
})
logger.info(f"ASR 识别结果: {user_text}")
finally:
# 清理 OSS 文件
try:
delete_audio_file(file_url)
except:
pass
# 3. LLM 生成回复
logger.info("开始 LLM 对话生成...")
# 获取用户的恋人信息
db = SessionLocal()
try:
lover = db.query(Lover).filter(Lover.user_id == user.id).first()
# 构建系统提示
system_prompt = f"你是用户 {user.nickname or '用户'} 的虚拟恋人,请用亲密、温暖、口语化的短句聊天。"
if lover and lover.personality_prompt:
system_prompt += f"\n人格设定:{lover.personality_prompt}"
# 构建对话历史
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_text}
]
# 调用 LLM
from ..llm import chat_completion
llm_result = chat_completion(messages)
ai_text = llm_result.content
logger.info(f"LLM 回复: {ai_text}")
finally:
db.close()
# 4. TTS 合成语音
logger.info("开始 TTS 语音合成...")
# 清理文本(去除 Markdown 等)
clean_text = re.sub(r"\*\*(.*?)\*\*", r"\1", ai_text)
clean_text = re.sub(r"`([^`]*)`", r"\1", clean_text)
clean_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", clean_text)
clean_text = re.sub(r"\*[^\*]{0,80}\*", "", clean_text)
clean_text = re.sub(r"[~]+", "", clean_text)
clean_text = clean_text.replace("*", "")
clean_text = re.sub(r"\s+", " ", clean_text).strip()
# 获取音色配置
db = SessionLocal()
try:
voice_code = None
lover = db.query(Lover).filter(Lover.user_id == user.id).first()
if lover and lover.voice_id:
voice = db.query(VoiceLibrary).filter(VoiceLibrary.id == lover.voice_id).first()
if voice and voice.voice_code:
voice_code = voice.voice_code
if not voice_code:
# 使用默认音色
gender = "female" if (user.gender or 0) == 1 else "male"
voice = db.query(VoiceLibrary).filter(
VoiceLibrary.gender == gender,
VoiceLibrary.is_default.is_(True)
).first()
if voice and voice.voice_code:
voice_code = voice.voice_code
else:
voice_code = settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
finally:
db.close()
# 调用 TTS
model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
audio_format_enum = AudioFormat.MP3_22050HZ_MONO_256KBPS
audio_bytes, _ = synthesize(
clean_text,
model=model,
voice=voice_code,
audio_format=audio_format_enum
)
logger.info(f"TTS 合成完成,音频大小: {len(audio_bytes)} 字节")
# 5. 返回结果
audio_base64_result = base64.b64encode(audio_bytes).decode('utf-8')
from ..response import success_response
return success_response({
"user_text": user_text,
"ai_text": ai_text,
"audio_data": audio_base64_result,
"audio_format": "mp3"
})
except HTTPException:
raise
except Exception as e:
logger.error(f"语音对话处理失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"语音对话处理失败: {str(e)}")