Ai_GirlFriend/lover/routers/voice_call.py
2026-01-31 19:15:41 +08:00

588 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import logging
import re
import time
from typing import List, Optional
import requests
import dashscope
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, status
from fastapi.websockets import WebSocketState
from ..config import settings
from ..deps import AuthedUser, _fetch_user_from_php
from ..llm import chat_completion_stream
from ..tts import synthesize
from ..db import SessionLocal
from ..models import Lover, VoiceLibrary
# Optional realtime-ASR client. Without dashscope, Recognition is None and the
# callback/result base types degrade to plain ``object`` so the classes in this
# module can still be defined; starting ASR then fails with an explicit error.
try:
    from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
except Exception:  # fallback when dashscope is not installed / fails to import
    Recognition = None
    RecognitionCallback = object  # type: ignore
    RecognitionResult = object  # type: ignore

# Optional TTS v2 types, degraded the same way when unavailable.
try:
    from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer, ResultCallback
except Exception:
    AudioFormat = None  # type: ignore
    SpeechSynthesizer = None  # type: ignore
    ResultCallback = object  # type: ignore

router = APIRouter(prefix="/voice", tags=["voice"])

logger = logging.getLogger("voice_call")
logger.setLevel(logging.INFO)

# Sentinel placed on the LLM -> TTS queue to mark the end of one assistant reply.
END_OF_TTS = "<<VOICE_CALL_TTS_END>>"
class WSRecognitionCallback(RecognitionCallback):  # type: ignore[misc]
    """ASR callback that pushes sentence-level results into a VoiceCallSession.

    All interaction with the session goes through ``session._schedule``, which
    uses ``call_soon_threadsafe``; presumably dashscope invokes these hooks from
    a thread other than the event loop's (confirm against the SDK docs).
    """

    def __init__(self, session: "VoiceCallSession"):
        super().__init__()
        self.session = session
        # Most recent partial (not yet sentence-final) transcript. It is submitted
        # as a full sentence if the stream ends before an end marker arrives.
        self._last_text: Optional[str] = None

    def on_open(self) -> None:
        logger.info("ASR connection opened")

    def on_complete(self) -> None:
        logger.info("ASR complete")
        # Treat the trailing partial as a finished sentence so it is not lost
        # when no explicit end-of-sentence marker was received.
        self._flush_last_text("complete")

    def on_error(self, result: RecognitionResult) -> None:
        logger.error("ASR error: %s", getattr(result, "message", None) or result)
        self._flush_last_text("error")

    def on_close(self) -> None:
        logger.info("ASR closed")
        self._flush_last_text("close")

    def _flush_last_text(self, reason: str) -> None:
        """Submit the pending partial transcript (if any) as a sentence, at most once.

        Args:
            reason: Name of the lifecycle hook that triggered the flush; used only
                in the log line.
        """
        text = self._last_text
        if not text:
            return
        # Clear first so a later hook (e.g. on_close after on_complete) cannot
        # submit the same text twice.
        self._last_text = None
        self.session._schedule(self.session.handle_sentence(text))
        logger.info("ASR flush last text on %s: %s", reason, text)

    def on_event(self, result: RecognitionResult) -> None:
        """Handle one recognition event.

        A sentence-final result is submitted to the session. A partial result is
        forwarded to the client as a ``partial_asr`` signal and remembered for a
        possible flush.
        """
        sentence = result.get_sentence()
        if not sentence:
            return
        sentences = sentence if isinstance(sentence, list) else [sentence]
        for sent in sentences:
            if not isinstance(sent, dict):
                continue
            text = sent.get("text")
            if not text:
                continue
            # Accept either flag spelling, or defer to the SDK's own check.
            is_end = (
                bool(sent.get("is_sentence_end"))
                or bool(sent.get("sentence_end"))
                or RecognitionResult.is_sentence_end(sent)
            )
            if is_end:
                self.session._schedule(self.session.handle_sentence(text))
                self._last_text = None
            else:
                self.session._schedule(self.session.send_signal({"type": "partial_asr", "text": text}))
                self._last_text = text
            logger.info("ASR event end=%s sentence=%s", is_end, sent)
async def authenticate_websocket(websocket: WebSocket) -> AuthedUser:
    """Authenticate a WebSocket connection using the same sources as the HTTP deps.

    Credentials are looked up in this order:

    1. ``Authorization: Bearer <token>`` header
    2. ``X-Token`` header
    3. ``token`` query parameter
    4. ``x_user_id`` query parameter (debug fallback, see the note below)

    A token is resolved to a user through ``_fetch_user_from_php``.

    Raises:
        HTTPException: 401 when the resolved payload has no user id, or when no
            usable credential is supplied.
    """
    headers = websocket.headers
    token = None
    auth_header = headers.get("authorization")
    if auth_header and auth_header.lower().startswith("bearer "):
        token = auth_header.split(" ", 1)[1].strip()
    if not token:
        token = headers.get("x-token")
    # Browsers cannot set headers on a WebSocket handshake, so a query token is accepted too.
    if not token:
        token = websocket.query_params.get("token")
    x_user_id = websocket.query_params.get("x_user_id")
    if token:
        # _fetch_user_from_php is a synchronous call (it is not awaited) and
        # presumably performs a blocking request to the user centre. Run it in the
        # default executor so one slow lookup does not stall every other
        # connection served by this event loop.
        loop = asyncio.get_running_loop()
        payload = await loop.run_in_executor(None, _fetch_user_from_php, token)
        payload = payload or {}
        user_id = payload.get("id") or payload.get("user_id")
        reg_step = payload.get("reg_step") or payload.get("stage") or 1
        gender = payload.get("gender") or 0
        nickname = payload.get("nickname") or payload.get("username") or ""
        if not user_id:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户中心缺少用户ID")
        return AuthedUser(
            id=user_id,
            reg_step=reg_step,
            gender=gender,
            nickname=nickname,
            token=token,
        )
    # NOTE(review): this debug fallback lets any client act as any user id without
    # a token. It should be disabled outside development. Confirm how the HTTP
    # dependency gates the same path.
    if x_user_id is not None:
        try:
            uid = int(x_user_id)
        except ValueError:
            uid = None
        if uid is not None:
            return AuthedUser(id=uid, reg_step=2, gender=0, nickname="debug-user", token="")
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或未授权")
class VoiceCallSession:
    """Per-connection state and background workers for one realtime voice call.

    Data flow: client PCM frames -> dashscope realtime ASR -> ``asr_to_llm`` ->
    streaming LLM -> ``llm_to_tts`` -> TTS -> binary audio frames back to the
    client. JSON control signals (``ready``, ``partial_asr``, ``reply_text``,
    ``reply_end``, ``info``, ``error``, ``interrupt``) are sent alongside.

    The exchange is turn-based: while ``is_speaking`` is True, newly recognised
    user sentences are rejected with an info message instead of being queued.
    """

    # Seconds without detected voice input (see feed_audio) before the call is closed.
    SILENCE_TIMEOUT_SECONDS = 60

    def __init__(self, websocket: WebSocket, user: AuthedUser, require_ptt: bool = False):
        self.websocket = websocket
        self.user = user
        # Push-to-talk: the mic starts muted and is toggled via set_mic_enabled.
        self.require_ptt = require_ptt
        self.mic_enabled = not require_ptt
        # Set in start(); ASR callbacks use it to hop back onto the event loop.
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.asr_to_llm: asyncio.Queue[str] = asyncio.Queue()
        self.llm_to_tts: asyncio.Queue[str] = asyncio.Queue()
        # True while an AI reply is being generated/played; handle_sentence
        # rejects user speech while it is set.
        self.is_speaking = False
        self.lover: Optional[Lover] = None
        self.db = SessionLocal()
        self.voice_code: Optional[str] = None
        # Must come after self.db / self.lover: composing the prompt loads the lover.
        self.history: List[dict] = [
            {
                "role": "system",
                "content": self._compose_system_prompt(),
            }
        ]
        self.llm_task: Optional[asyncio.Task] = None
        self.tts_task: Optional[asyncio.Task] = None
        self.tts_stream_task: Optional[asyncio.Task] = None
        self.silence_task: Optional[asyncio.Task] = None
        self.cancel_event = asyncio.Event()
        self.recognition: Optional[Recognition] = None
        self.idle_task: Optional[asyncio.Task] = None
        self.last_activity = time.time()
        self.last_voice_activity = time.time()
        self.has_voice_input = False
        self.last_interrupt_time = 0.0
        # Selects the shorter phrase threshold for the first TTS chunk of a reply.
        self.tts_first_chunk = True
        # Strong references to tasks created by _schedule. The event loop keeps only
        # weak references to tasks, so an unreferenced task may be collected mid-run.
        self._bg_tasks: set = set()
        self._closed = False

    async def start(self):
        """Accept the socket, preload the profile, start ASR and the background workers.

        Raises:
            HTTPException: from _start_asr when dashscope or its API key is missing.
                This happens after the socket has already been accepted.
        """
        await self.websocket.accept()
        self.loop = asyncio.get_running_loop()
        # Preload lover and voice so the streaming path does not need DB lookups.
        self._prepare_profile()
        self._start_asr()
        self.llm_task = asyncio.create_task(self._process_llm_loop())
        self.tts_task = asyncio.create_task(self._process_tts_loop())
        self.idle_task = asyncio.create_task(self._idle_watchdog())
        self.silence_task = asyncio.create_task(self._silence_watchdog())
        await self.send_signal({"type": "ready"})
        if self.require_ptt:
            await self.send_signal({"type": "info", "msg": "ptt_enabled"})

    def _start_asr(self):
        """Create and start a dashscope realtime Recognition bound to this session.

        Raises:
            HTTPException: 500 if dashscope is not installed or no API key is configured.
        """
        if Recognition is None:
            raise HTTPException(status_code=500, detail="未安装 dashscope无法启动实时 ASR")
        if not settings.DASHSCOPE_API_KEY:
            raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
        dashscope.api_key = settings.DASHSCOPE_API_KEY
        model = settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2"
        sample_rate = settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000
        callback = WSRecognitionCallback(self)
        self.recognition = Recognition(
            model=model,
            format="pcm",
            sample_rate=sample_rate,
            api_key=settings.DASHSCOPE_API_KEY,
            callback=callback,
        )
        logger.info("ASR started model=%s sample_rate=%s", model, sample_rate)
        self.recognition.start()

    async def handle_sentence(self, text: str):
        """Queue a finished ASR sentence for the LLM, unless the AI is mid-reply."""
        # Turn-based: while the AI is speaking, user speech is ignored and the
        # client is told to wait.
        if self.is_speaking:
            await self.send_signal({"type": "info", "msg": "请等待 AI 说完再讲话"})
            return
        logger.info("Handle sentence: %s", text)
        await self.asr_to_llm.put(text)

    async def _process_llm_loop(self):
        """Worker: take user sentences off ``asr_to_llm`` and generate replies one at a time."""
        while True:
            text = await self.asr_to_llm.get()
            self.cancel_event.clear()
            try:
                await self._stream_llm(text)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.exception("LLM error", exc_info=exc)
                await self.send_signal({"type": "error", "msg": "LLM 生成失败"})
                self.is_speaking = False

    def _trim_history(self) -> None:
        """Cap history at ``VOICE_CALL_MAX_HISTORY`` messages, keeping the system prompt.

        A falsy limit disables trimming. The leading system message is always
        kept, plus at least the newest message. With a limit of 1 this yields
        two messages.
        """
        limit = settings.VOICE_CALL_MAX_HISTORY
        if not limit or len(self.history) <= limit:
            return
        head: List[dict] = []
        tail = self.history
        if tail and tail[0].get("role") == "system":
            head, tail = tail[:1], tail[1:]
        keep = max(limit - len(head), 1)
        self.history = head + tail[-keep:]

    async def _stream_llm(self, text: str):
        """Add *text* as a user turn and stream the LLM reply into the TTS queue.

        Each chunk is forwarded to ``llm_to_tts`` as it arrives. END_OF_TTS follows
        unless the turn was cancelled. The full reply is then appended to the history
        and echoed to the client as ``reply_text``.
        """
        self.history.append({"role": "user", "content": text})
        self._trim_history()
        # Mark the turn as active before the first await, so sentences arriving while
        # the request is being set up are rejected rather than queued.
        self.is_speaking = True
        self.tts_first_chunk = True
        loop = asyncio.get_running_loop()
        history_snapshot = list(self.history)
        # chat_completion_stream is consumed as a plain (synchronous) iterable, so
        # each step may block on the network. Create it and pull every chunk in the
        # default executor. That keeps the event loop responsive, and lets the TTS
        # worker speak earlier phrases while later ones are still being generated.
        stream = await loop.run_in_executor(
            None, lambda: iter(chat_completion_stream(history_snapshot))
        )
        done = object()  # next() default; avoids StopIteration crossing a Future
        buffer = []
        while True:
            chunk = await loop.run_in_executor(None, next, stream, done)
            if chunk is done or self.cancel_event.is_set():
                break
            buffer.append(chunk)
            await self.llm_to_tts.put(chunk)
        if not self.cancel_event.is_set():
            await self.llm_to_tts.put(END_OF_TTS)
        full_reply = "".join(buffer)
        self.history.append({"role": "assistant", "content": full_reply})
        if full_reply:
            # Send the full text too, for client-side display/debugging.
            await self.send_signal({"type": "reply_text", "text": full_reply})
        else:
            self.is_speaking = False

    async def _process_tts_loop(self):
        """Worker: group LLM tokens into short phrases and stream them as audio.

        A phrase is flushed when a token ends in punctuation, or when the buffered
        text reaches 8 characters (first phrase of a reply, so audio starts quickly)
        or 18 characters (later phrases). END_OF_TTS flushes the remainder and always
        sends ``reply_end``, even if that last synthesis failed. Tokens that arrive
        while ``cancel_event`` is set are dropped.
        """
        pending: List[str] = []
        punctuations = set(",。?!,.?!;")
        while True:
            token = await self.llm_to_tts.get()
            if self.cancel_event.is_set():
                pending = []
                self.tts_first_chunk = True
                continue
            if token == END_OF_TTS:
                tail = "".join(pending)
                pending = []
                if tail and not await self._speak(tail):
                    break
                self.tts_first_chunk = True
                self.is_speaking = False
                await self.send_signal({"type": "reply_end"})
                continue
            pending.append(token)
            last_char = token[-1] if token else ""
            threshold = 8 if self.tts_first_chunk else 18
            if last_char in punctuations or len("".join(pending)) >= threshold:
                phrase = "".join(pending)
                pending = []
                self.tts_first_chunk = False
                if not await self._speak(phrase):
                    break

    async def _speak(self, text: str) -> bool:
        """Clean *text*, synthesize it and stream the audio bytes to the client.

        Returns:
            False only if the client disconnected, meaning the TTS loop should stop.
            Synthesis errors are reported to the client as ``tts_failed``, the
            speaking flag is cleared, and True is returned.
        """
        clean_text = self._clean_tts_text(text)
        if not clean_text:
            return True
        try:
            async for chunk in self._synthesize_stream(clean_text):
                if self.cancel_event.is_set():
                    break
                await self.websocket.send_bytes(chunk)
                self._touch()
        except WebSocketDisconnect:
            return False
        except Exception as exc:
            logger.exception("TTS error", exc_info=exc)
            await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
            self.is_speaking = False
        return True

    async def _synthesize_stream(self, text: str):
        """Yield synthesized audio for *text*.

        Despite the generator interface, this makes one non-streaming
        ``synthesize`` call and yields its whole result as a single chunk. That call
        is synchronous, so it runs in the default executor rather than on the event
        loop.

        Raises:
            RuntimeError: if dashscope's tts_v2 ``AudioFormat`` is unavailable.
        """
        model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
        voice = self._pick_voice_code() or settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
        fmt = settings.VOICE_CALL_TTS_FORMAT.lower() if settings.VOICE_CALL_TTS_FORMAT else "mp3"
        if AudioFormat is None:
            raise RuntimeError("dashscope tts_v2 AudioFormat is not available")
        audio_format = AudioFormat.MP3_22050HZ_MONO_256KBPS if fmt == "mp3" else AudioFormat.PCM_16000HZ_MONO
        loop = asyncio.get_running_loop()
        audio_bytes, _fmt_name = await loop.run_in_executor(
            None,
            lambda: synthesize(text, model=model, voice=voice, audio_format=audio_format),  # type: ignore[arg-type]
        )
        yield audio_bytes

    async def feed_audio(self, data: bytes):
        """Forward one PCM frame from the client to ASR and update activity timers."""
        if self.require_ptt and not self.mic_enabled:
            # PTT mode with the mic released: drop the audio but count it as activity.
            self._touch()
            return
        # ASR may have been stopped (finalize_asr / mic_off), so restart it lazily.
        # NOTE: ``_running`` is a private attribute of dashscope's Recognition.
        if not (self.recognition and getattr(self.recognition, "_running", False)):
            try:
                self._start_asr()
            except Exception as exc:
                logger.error("ASR restart failed: %s", exc)
                return
        if self.recognition:
            self.recognition.send_audio_frame(data)
            logger.debug("recv audio chunk bytes=%s", len(data))
        # Peak amplitude only feeds activity tracking (silence watchdog); it does not
        # trigger barge-in.
        if self._peak_pcm16(data) > 300:
            self.last_voice_activity = time.time()
            self.has_voice_input = True
        self._touch()

    def finalize_asr(self):
        """Stop ASR so it emits its final result. feed_audio restarts it on the next frame."""
        try:
            if self.recognition:
                self.recognition.stop()
                logger.info("ASR stop requested manually")
        except Exception as exc:
            logger.warning("ASR stop failed: %s", exc)

    async def set_mic_enabled(self, enabled: bool, flush: bool = False):
        """Toggle the push-to-talk mic. This is a no-op when PTT is not required.

        Args:
            enabled: New mic state.
            flush: When disabling, also finalize ASR so the utterance is emitted.
        """
        if not self.require_ptt:
            return
        self.mic_enabled = enabled
        await self.send_signal({"type": "info", "msg": "mic_on" if enabled else "mic_off"})
        if not enabled and flush:
            self.finalize_asr()

    def _schedule(self, coro):
        """Run *coro* as a task on the session loop. Safe to call from any thread.

        If the loop is not set, the loop is closed, or the session is closed, the
        coroutine is closed without running. That avoids "coroutine was never
        awaited" warnings.
        """
        loop = self.loop
        if loop is None or loop.is_closed() or self._closed:
            coro.close()
            return

        def _spawn() -> None:
            task = loop.create_task(coro)
            self._bg_tasks.add(task)
            task.add_done_callback(self._bg_tasks.discard)

        try:
            loop.call_soon_threadsafe(_spawn)
        except RuntimeError:
            # The loop was closed between the check above and this call.
            coro.close()

    def _pick_voice_code(self) -> Optional[str]:
        """Return the TTS voice code from the lover config or a default voice (may be None)."""
        if self.voice_code:
            return self.voice_code
        self._prepare_profile()
        return self.voice_code

    async def _interrupt(self):
        """Cancel the current reply: drop queued TTS tokens and notify the client."""
        self.cancel_event.set()
        # Drain pending tokens so the TTS worker does not speak stale text.
        while not self.llm_to_tts.empty():
            try:
                self.llm_to_tts.get_nowait()
            except Exception:
                break
        await self.send_signal({"type": "interrupt", "code": "interrupted", "msg": "AI 打断,停止播放"})
        self.is_speaking = False
        self.last_interrupt_time = time.time()

    def _ws_open(self) -> bool:
        """True while neither the client nor this app has closed the WebSocket."""
        ws = self.websocket
        return (
            ws.client_state == WebSocketState.CONNECTED
            and ws.application_state == WebSocketState.CONNECTED
        )

    async def close(self):
        """Release the DB session, stop ASR, cancel workers and close the socket.

        Idempotent: only the first call does anything. The task that calls close()
        (e.g. a watchdog) is not cancelled, so it can finish closing the socket.
        """
        if self._closed:
            return
        self._closed = True
        if self.db:
            try:
                self.db.close()
            except Exception as exc:
                logger.warning("DB session close failed: %s", exc)
        if self.recognition:
            try:
                self.recognition.stop()
            except Exception as exc:
                logger.debug("ASR stop on close failed: %s", exc)
        current = asyncio.current_task()
        workers = (
            self.llm_task,
            self.tts_task,
            self.tts_stream_task,
            self.idle_task,
            self.silence_task,
            *tuple(self._bg_tasks),
        )
        for task in workers:
            if task is not None and task is not current:
                task.cancel()
        # Check application_state too: after this app has sent a close frame,
        # client_state stays CONNECTED, and a second close() would raise.
        if self._ws_open():
            try:
                await self.websocket.close()
            except (RuntimeError, WebSocketDisconnect) as exc:
                logger.debug("WebSocket close failed: %s", exc)

    async def send_signal(self, payload: dict):
        """Send *payload* as a JSON text frame. This is a no-op once the socket is closed."""
        if not self._ws_open():
            return
        try:
            await self.websocket.send_text(json.dumps(payload, ensure_ascii=False))
            self._touch()
        except WebSocketDisconnect:
            return

    def _load_lover(self) -> Optional[Lover]:
        """Return (and cache) the current user's Lover row, or None if missing or on DB error."""
        if self.lover is not None:
            return self.lover
        try:
            self.lover = self.db.query(Lover).filter(Lover.user_id == self.user.id).first()
        except Exception as exc:
            logger.warning("Load lover failed: %s", exc)
            self.lover = None
        return self.lover

    def _compose_system_prompt(self) -> str:
        """Build the system prompt: fixed style/safety rules plus the lover's persona if set."""
        parts = [
            f"你是用户 {self.user.nickname or '用户'} 的虚拟恋人,请用亲密、温暖、口语化的短句聊天,不要使用 Markdown 符号,不要输出表情、波浪线、星号或动作描述。",
            "回复必须是对话内容,不要包含括号/星号/动作描写/舞台指令,不要用拟声词凑字数,保持简短自然的中文口语句子。",
            "禁止涉政、违法、暴力、未成年相关内容。",
        ]
        lover = self._load_lover()
        if lover and lover.personality_prompt:
            parts.append(f"人格设定:{lover.personality_prompt}")
        return "\n".join(parts)

    @staticmethod
    def _clean_tts_text(text: str) -> str:
        """Strip Markdown markup, *action* snippets and tildes, and collapse whitespace."""
        if not text:
            return ""
        # Unwrap common Markdown/code markup, keeping the inner text.
        text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
        text = re.sub(r"`([^`]*)`", r"\1", text)
        text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
        text = re.sub(r"\*[^\*]{0,80}\*", "", text)  # drop *action* snippets
        text = re.sub(r"[~]+", "", text)  # drop tildes
        text = text.replace("*", "")
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    def _prepare_profile(self) -> None:
        """Preload the lover and resolve ``self.voice_code``; failures are logged, not raised.

        Resolution order:
        1. the lover's own voice;
        2. the default voice for the chosen gender;
        3. the lowest-id voice for that gender.
        """
        try:
            lover = self._load_lover()
            if lover and lover.voice_id:
                voice = self.db.query(VoiceLibrary).filter(VoiceLibrary.id == lover.voice_id).first()
                if voice and voice.voice_code:
                    self.voice_code = voice.voice_code
                    return
            gender = None
            if lover and lover.gender:
                gender = lover.gender
            if not gender:
                # User gender 1 maps to a "female" voice, anything else to "male".
                # Presumably this picks the opposite gender; confirm the user gender codes.
                gender = "female" if (self.user.gender or 0) == 1 else "male"
            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender, VoiceLibrary.is_default.is_(True))
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
                return
            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender)
                .order_by(VoiceLibrary.id.asc())
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
        except Exception as exc:
            logger.warning("Prepare profile failed: %s", exc)

    def _touch(self):
        """Record activity for the idle watchdog."""
        self.last_activity = time.time()

    async def _idle_watchdog(self):
        """Close the session after VOICE_CALL_IDLE_TIMEOUT seconds of no activity (0 disables it)."""
        timeout = settings.VOICE_CALL_IDLE_TIMEOUT or 0
        if timeout <= 0:
            return
        try:
            while True:
                await asyncio.sleep(5)
                if time.time() - self.last_activity > timeout:
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    async def _silence_watchdog(self):
        """Close the session after a long stretch without voice input.

        ASR stays running; it is not stopped on short pauses.
        """
        try:
            while True:
                await asyncio.sleep(1.0)
                if time.time() - self.last_voice_activity > self.SILENCE_TIMEOUT_SECONDS:
                    logger.info("Long silence, closing session")
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    @staticmethod
    def _peak_pcm16(data: bytes) -> int:
        """Return the peak absolute sample of little-endian signed 16-bit PCM *data*.

        A trailing odd byte is ignored, and empty input yields 0.
        """
        from array import array
        import sys

        usable = len(data) - (len(data) % 2)
        if usable <= 0:
            return 0
        samples = array("h")
        samples.frombytes(memoryview(data)[:usable])
        # array uses native byte order; the input is little-endian.
        if sys.byteorder == "big":
            samples.byteswap()
        # -min covers -32768, whose magnitude (32768) exceeds any positive sample.
        return max(max(samples), -min(samples))
@router.websocket("/call")
async def voice_call(websocket: WebSocket):
    """Realtime voice-call endpoint.

    After authentication and the ``ready`` signal, the client may send:

    - binary frames: PCM audio forwarded to ASR;
    - text ``mic_on``/``ptt_on`` or ``mic_off``/``ptt_off`` (case-insensitive):
      push-to-talk control, effective only in PTT mode;
    - text ``ping``: answered with a raw ``pong`` frame;
    - text ``end``/``stop``/``flush``: force ASR to finalize the current utterance;
    - any other text, which is acknowledged and ignored.

    PTT mode is on when ``settings.VOICE_CALL_REQUIRE_PTT`` is set or the ``ptt``
    query parameter is 1/true/yes/on.
    """
    try:
        user = await authenticate_websocket(websocket)
    except HTTPException:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    ptt_param = (websocket.query_params.get("ptt") or "").strip().lower()
    require_ptt = settings.VOICE_CALL_REQUIRE_PTT or ptt_param in ("1", "true", "yes", "on")
    session = VoiceCallSession(websocket, user, require_ptt=require_ptt)
    # Everything after construction sits in try/finally so the session (which
    # holds a DB session) is always released, including when start() fails.
    try:
        try:
            await session.start()
        except HTTPException as exc:
            # start() may fail after it has already accepted the socket (e.g. in
            # ASR setup). Accepting twice raises, so accept only while still connecting.
            try:
                if websocket.client_state == WebSocketState.CONNECTING:
                    await websocket.accept()
                await websocket.send_text(json.dumps({"type": "error", "msg": exc.detail}))
                await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
            except Exception as report_exc:
                logger.warning("Failed to report voice call start error: %s", report_exc)
            return
        while True:
            msg = await websocket.receive()
            if "bytes" in msg and msg["bytes"] is not None:
                await session.feed_audio(msg["bytes"])
            elif "text" in msg and msg["text"]:
                # Simple heartbeat / signalling commands.
                text = msg["text"].strip()
                lower_text = text.lower()
                if lower_text in ("mic_on", "ptt_on"):
                    await session.set_mic_enabled(True)
                elif lower_text in ("mic_off", "ptt_off"):
                    await session.set_mic_enabled(False, flush=True)
                elif text == "ping":
                    await websocket.send_text("pong")
                elif text in ("end", "stop", "flush"):
                    session.finalize_asr()
                    await session.send_signal({"type": "info", "msg": "ASR stopped manually"})
                else:
                    await session.send_signal({"type": "info", "msg": "文本消息已忽略"})
            if msg.get("type") == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        pass
    finally:
        try:
            await session.close()
        except RuntimeError as exc:
            # The socket may already have been closed above. Re-closing raises
            # RuntimeError; the session's other resources are released by then.
            logger.debug("Session close after socket shutdown: %s", exc)