Ai_GirlFriend/lover/routers/voice_call.py

619 lines
24 KiB
Python
Raw Normal View History

2026-01-31 19:15:41 +08:00
import asyncio
import json
import logging
import re
import time
from typing import List, Optional
import requests
import dashscope
2026-02-02 20:08:28 +08:00
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status
2026-01-31 19:15:41 +08:00
from fastapi.websockets import WebSocketState
from ..config import settings
2026-02-02 20:08:28 +08:00
from ..deps import AuthedUser, get_current_user, _fetch_user_from_php
2026-01-31 19:15:41 +08:00
from ..llm import chat_completion_stream
from ..tts import synthesize
from ..db import SessionLocal
2026-02-02 20:08:28 +08:00
from ..models import Lover, VoiceLibrary, User
2026-01-31 19:15:41 +08:00
# Optional dashscope SDK imports: provide stub fallbacks when the SDK is not
# installed so this module can still be imported (ASR startup will then fail
# with an explicit HTTP 500 instead of an ImportError at import time).
try:
    from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
except Exception:  # fallback stubs when dashscope is not installed
    Recognition = None
    RecognitionCallback = object  # type: ignore
    RecognitionResult = object  # type: ignore
try:
    from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer, ResultCallback
except Exception:
    AudioFormat = None  # type: ignore
    SpeechSynthesizer = None  # type: ignore
    ResultCallback = object  # type: ignore
router = APIRouter(prefix="/voice", tags=["voice"])
logger = logging.getLogger("voice_call")
logger.setLevel(logging.INFO)
# Sentinel token pushed through the LLM->TTS queue to mark the end of a reply.
END_OF_TTS = "<<VOICE_CALL_TTS_END>>"
2026-02-02 20:08:28 +08:00
@router.get("/call/duration")
async def get_call_duration(user: AuthedUser = Depends(get_current_user)):
    """Return the voice-call duration quota for the current user.

    VIP users (``vip_endtime`` in the future) get an unlimited call
    (``duration == 0``); everyone else gets 5 minutes (300000 ms).
    """
    # Local import only for the helper not imported at module top;
    # SessionLocal and User are already module-level imports.
    from ..response import success_response

    db = SessionLocal()
    try:
        user_row = db.query(User).filter(User.id == user.id).first()
        if not user_row:
            raise HTTPException(status_code=404, detail="用户不存在")
        # vip_endtime is a Unix timestamp, so compare against real epoch seconds.
        # BUG FIX: the previous datetime.utcnow().timestamp() treats the naive
        # UTC datetime as *local* time, skewing the epoch on non-UTC servers.
        now_ts = int(time.time())
        # bool() so the JSON field is always true/false, never null/0.
        is_vip = bool(user_row.vip_endtime and user_row.vip_endtime > now_ts)
        duration = 0 if is_vip else 300000  # 0 = unlimited; otherwise 5 minutes in ms
        return success_response({
            "duration": duration,
            "is_vip": is_vip,
        })
    finally:
        db.close()
2026-01-31 19:15:41 +08:00
class WSRecognitionCallback(RecognitionCallback):  # type: ignore[misc]
    """ASR callback that pushes sentence-level results into the session queues.

    NOTE(review): dashscope invokes these callbacks on an SDK worker thread,
    hence every hand-off to the session goes through ``session._schedule``
    (``call_soon_threadsafe``) rather than awaiting directly — confirm against
    the dashscope SDK threading contract.
    """

    def __init__(self, session: "VoiceCallSession"):
        super().__init__()
        self.session = session
        # Last partial (non-final) transcript, flushed as a sentence if the
        # stream ends without an explicit sentence-end marker.
        self._last_text: Optional[str] = None

    def on_open(self) -> None:
        logger.info("ASR connection opened")

    def on_complete(self) -> None:
        logger.info("ASR complete")
        if self._last_text:
            # Treat the trailing partial as a finished sentence so it is not
            # lost when no end marker arrived.
            self.session._schedule(self.session.handle_sentence(self._last_text))
            logger.info("ASR flush last text on complete: %s", self._last_text)
            self._last_text = None

    def on_error(self, result: RecognitionResult) -> None:
        logger.error("ASR error: %s", getattr(result, "message", None) or result)
        if self._last_text:
            # Same flush-on-termination behavior as on_complete/on_close.
            self.session._schedule(self.session.handle_sentence(self._last_text))
            logger.info("ASR flush last text on error: %s", self._last_text)
            self._last_text = None

    def on_close(self) -> None:
        logger.info("ASR closed")
        if self._last_text:
            self.session._schedule(self.session.handle_sentence(self._last_text))
            logger.info("ASR flush last text on close: %s", self._last_text)
            self._last_text = None

    def on_event(self, result: RecognitionResult) -> None:
        # A result may carry one sentence dict or a list of them; normalize.
        sentence = result.get_sentence()
        if not sentence:
            return
        sentences = sentence if isinstance(sentence, list) else [sentence]
        for sent in sentences:
            text = sent.get("text") if isinstance(sent, dict) else None
            if not text:
                continue
            is_end = False
            if isinstance(sent, dict):
                # Check all known end-of-sentence markers; field names vary
                # across SDK versions — presumably at least one is set on a
                # final sentence (verify against the installed dashscope).
                is_end = (
                    bool(sent.get("is_sentence_end"))
                    or bool(sent.get("sentence_end"))
                    or RecognitionResult.is_sentence_end(sent)
                )
            if is_end:
                # Final sentence: hand it to the LLM pipeline.
                self.session._schedule(self.session.handle_sentence(text))
                self._last_text = None
            else:
                # Interim result: forward to the client for live display.
                self.session._schedule(self.session.send_signal({"type": "partial_asr", "text": text}))
                self._last_text = text
            logger.info("ASR event end=%s sentence=%s", is_end, sent)
async def authenticate_websocket(websocket: WebSocket) -> AuthedUser:
    """Authenticate a WebSocket client, mirroring the HTTP auth logic.

    Token sources, in priority order: ``Authorization: Bearer`` header,
    ``X-Token`` header, ``token`` query parameter.  As a debug fallback a
    numeric ``x_user_id`` query parameter is accepted.

    Raises:
        HTTPException: 401 when no usable credential is found.
    """
    headers = websocket.headers
    bearer = headers.get("authorization")
    token: Optional[str] = None
    if bearer and bearer.lower().startswith("bearer "):
        token = bearer.split(" ", 1)[1].strip()
    # Fall through the remaining sources until one yields a non-empty token.
    token = token or headers.get("x-token") or websocket.query_params.get("token")

    if token:
        payload = _fetch_user_from_php(token)
        user_id = payload.get("id") or payload.get("user_id")
        if not user_id:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户中心缺少用户ID")
        return AuthedUser(
            id=user_id,
            reg_step=payload.get("reg_step") or payload.get("stage") or 1,
            gender=payload.get("gender") or 0,
            nickname=payload.get("nickname") or payload.get("username") or "",
            token=token,
        )

    debug_id = websocket.query_params.get("x_user_id")
    if debug_id is not None:
        try:
            uid: Optional[int] = int(debug_id)
        except Exception:
            uid = None
        if uid is not None:
            return AuthedUser(id=uid, reg_step=2, gender=0, nickname="debug-user", token="")

    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或未授权")
class VoiceCallSession:
    """State machine for one live voice-call WebSocket session.

    Pipeline: client PCM audio -> dashscope streaming ASR -> sentence queue
    (``asr_to_llm``) -> streaming LLM reply -> token queue (``llm_to_tts``)
    -> chunked TTS -> binary audio frames back to the client.

    Threading: ASR callbacks arrive on an SDK thread; they are marshalled
    onto the event loop via ``_schedule`` / ``call_soon_threadsafe``.
    Turn-taking is half-duplex: while ``is_speaking`` is set, user speech
    is rejected with an info message.
    """

    def __init__(self, websocket: WebSocket, user: AuthedUser, require_ptt: bool = False):
        self.websocket = websocket
        self.user = user
        # Push-to-talk mode: when enabled, incoming audio is dropped unless
        # the client has sent mic_on/ptt_on.
        self.require_ptt = require_ptt
        self.mic_enabled = not require_ptt
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        # Final ASR sentences awaiting the LLM.
        self.asr_to_llm: asyncio.Queue[str] = asyncio.Queue()
        # LLM text tokens awaiting TTS; END_OF_TTS marks the end of a reply.
        self.llm_to_tts: asyncio.Queue[str] = asyncio.Queue()
        self.is_speaking = False
        self.lover: Optional[Lover] = None
        self.db = SessionLocal()
        self.voice_code: Optional[str] = None
        # Rolling chat history, seeded with the persona system prompt.
        self.history: List[dict] = [
            {
                "role": "system",
                "content": self._compose_system_prompt(),
            }
        ]
        self.llm_task: Optional[asyncio.Task] = None
        self.tts_task: Optional[asyncio.Task] = None
        self.tts_stream_task: Optional[asyncio.Task] = None
        self.silence_task: Optional[asyncio.Task] = None
        # Set to abort the current LLM/TTS turn (barge-in / interrupt).
        self.cancel_event = asyncio.Event()
        self.recognition: Optional[Recognition] = None
        self.idle_task: Optional[asyncio.Task] = None
        self.last_activity = time.time()
        self.last_voice_activity = time.time()
        self.has_voice_input = False
        self.last_interrupt_time = 0.0
        # First TTS chunk of a reply uses a shorter flush threshold to cut
        # time-to-first-audio.
        self.tts_first_chunk = True

    async def start(self):
        """Accept the socket and spin up ASR plus all background tasks."""
        await self.websocket.accept()
        self.loop = asyncio.get_running_loop()
        # Preload lover profile and voice so the streaming path never blocks
        # the event loop on DB lookups.
        self._prepare_profile()
        # Start realtime ASR (raises HTTPException if dashscope is unusable).
        self._start_asr()
        # Background LLM/TTS consumers and watchdogs.
        self.llm_task = asyncio.create_task(self._process_llm_loop())
        self.tts_task = asyncio.create_task(self._process_tts_loop())
        self.idle_task = asyncio.create_task(self._idle_watchdog())
        self.silence_task = asyncio.create_task(self._silence_watchdog())
        await self.send_signal({"type": "ready"})
        if self.require_ptt:
            await self.send_signal({"type": "info", "msg": "ptt_enabled"})

    def _start_asr(self):
        """Create and start the dashscope streaming recognizer.

        Raises:
            HTTPException: 500 when dashscope is missing or unconfigured.
        """
        if Recognition is None:
            raise HTTPException(status_code=500, detail="未安装 dashscope无法启动实时 ASR")
        if not settings.DASHSCOPE_API_KEY:
            raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
        dashscope.api_key = settings.DASHSCOPE_API_KEY
        callback = WSRecognitionCallback(self)
        self.recognition = Recognition(
            model=settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
            format="pcm",
            sample_rate=settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
            api_key=settings.DASHSCOPE_API_KEY,
            callback=callback,
        )
        logger.info(
            "ASR started model=%s sample_rate=%s",
            settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
            settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
        )
        self.recognition.start()

    async def handle_sentence(self, text: str):
        """Queue a final ASR sentence for the LLM (half-duplex gate)."""
        # Turn-based: while the AI is speaking, ignore user speech and ask
        # the user to wait.
        if self.is_speaking:
            await self.send_signal({"type": "info", "msg": "请等待 AI 说完再讲话"})
            return
        logger.info("Handle sentence: %s", text)
        await self.asr_to_llm.put(text)

    async def _process_llm_loop(self):
        """Consume sentences and stream LLM replies, one turn at a time."""
        while True:
            text = await self.asr_to_llm.get()
            # Fresh turn: clear any stale interrupt flag.
            self.cancel_event.clear()
            try:
                await self._stream_llm(text)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.exception("LLM error", exc_info=exc)
                await self.send_signal({"type": "error", "msg": "LLM 生成失败"})
                self.is_speaking = False

    async def _stream_llm(self, text: str):
        """Stream one LLM reply into the TTS queue and record history."""
        self.history.append({"role": "user", "content": text})
        # Bound the history length to the configured window.
        if len(self.history) > settings.VOICE_CALL_MAX_HISTORY:
            self.history = self.history[-settings.VOICE_CALL_MAX_HISTORY:]
        stream = chat_completion_stream(self.history)
        self.is_speaking = True
        self.tts_first_chunk = True
        buffer = []
        for chunk in stream:
            if self.cancel_event.is_set():
                break
            buffer.append(chunk)
            await self.llm_to_tts.put(chunk)
        # Only signal end-of-reply to TTS when the turn was not interrupted.
        if not self.cancel_event.is_set():
            await self.llm_to_tts.put(END_OF_TTS)
        full_reply = "".join(buffer)
        self.history.append({"role": "assistant", "content": full_reply})
        if full_reply:
            # Also send the full reply text for frontend display/debugging.
            await self.send_signal({"type": "reply_text", "text": full_reply})
        else:
            self.is_speaking = False

    async def _process_tts_loop(self):
        """Consume LLM tokens, batch them into phrases, and stream TTS audio.

        Tokens are buffered until a punctuation mark or a length threshold
        (shorter for the first chunk to reduce time-to-first-audio), then
        synthesized and sent as binary WebSocket frames.
        """
        temp_buffer = []
        punctuations = set(",。?!,.?!;")
        while True:
            token = await self.llm_to_tts.get()
            if self.cancel_event.is_set():
                # Interrupted: drop buffered text and reset for the next turn.
                temp_buffer = []
                self.tts_first_chunk = True
                continue
            if token == END_OF_TTS:
                # Flush whatever is left in the buffer for this reply.
                if temp_buffer:
                    text_chunk = "".join(temp_buffer)
                    temp_buffer = []
                    clean_text = self._clean_tts_text(text_chunk)
                    if clean_text:
                        try:
                            async for chunk in self._synthesize_stream(clean_text):
                                if self.cancel_event.is_set():
                                    break
                                await self.websocket.send_bytes(chunk)
                                self._touch()
                        except WebSocketDisconnect:
                            break
                        except Exception as exc:
                            logger.exception("TTS error", exc_info=exc)
                            await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
                            self.is_speaking = False
                            continue
                self.tts_first_chunk = True
                self.is_speaking = False
                await self.send_signal({"type": "reply_end"})
                continue
            temp_buffer.append(token)
            last_char = token[-1] if token else ""
            # Shorter threshold for the first chunk so audio starts sooner.
            threshold = 8 if self.tts_first_chunk else 18
            if last_char in punctuations or len("".join(temp_buffer)) >= threshold:
                text_chunk = "".join(temp_buffer)
                temp_buffer = []
                self.tts_first_chunk = False
                clean_text = self._clean_tts_text(text_chunk)
                if not clean_text:
                    continue
                try:
                    async for chunk in self._synthesize_stream(clean_text):
                        if self.cancel_event.is_set():
                            break
                        await self.websocket.send_bytes(chunk)
                        self._touch()
                except WebSocketDisconnect:
                    break
                except Exception as exc:
                    logger.exception("TTS error", exc_info=exc)
                    await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
                    self.is_speaking = False
        # Unreachable today; kept in case the loop logic changes.

    async def _synthesize_stream(self, text: str):
        """Synthesize *text* and yield audio bytes.

        Despite the name this currently performs a single one-shot
        synthesis (cosyvoice v2) and yields one chunk; true streaming was
        dropped in favor of the simpler synchronous call.

        NOTE(review): ``synthesize`` is called synchronously inside an async
        generator — if it does network I/O it blocks the event loop for the
        duration; consider ``asyncio.to_thread``. TODO confirm.
        """
        model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
        voice = self._pick_voice_code() or settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
        fmt = settings.VOICE_CALL_TTS_FORMAT.lower() if settings.VOICE_CALL_TTS_FORMAT else "mp3"
        audio_format = AudioFormat.MP3_22050HZ_MONO_256KBPS if fmt == "mp3" else AudioFormat.PCM_16000HZ_MONO
        # One-shot synchronous synthesis to avoid streaming complexity.
        audio_bytes, _fmt_name = synthesize(text, model=model, voice=voice, audio_format=audio_format)  # type: ignore[arg-type]
        yield audio_bytes

    async def feed_audio(self, data: bytes):
        """Forward one audio frame from the client into the recognizer."""
        if self.require_ptt and not self.mic_enabled:
            # PTT mode with the button released: discard audio but keep the
            # idle timer alive.
            self._touch()
            return
        # Lazily restart ASR if it was stopped earlier.
        # NOTE(review): relies on the SDK-private ``_running`` attribute —
        # verify it exists on the installed dashscope version.
        if not (self.recognition and getattr(self.recognition, "_running", False)):
            try:
                self._start_asr()
            except Exception as exc:
                logger.error("ASR restart failed: %s", exc)
                return
        if self.recognition:
            self.recognition.send_audio_frame(data)
            logger.debug("recv audio chunk bytes=%s", len(data))
        peak = self._peak_pcm16(data)
        now = time.time()
        if peak > 300:  # activity detection only; no longer triggers barge-in
            self.last_voice_activity = now
            self.has_voice_input = True
        self._touch()

    def finalize_asr(self):
        """Stop ASR explicitly to force it to emit its final result."""
        try:
            if self.recognition:
                self.recognition.stop()
                logger.info("ASR stop requested manually")
        except Exception as exc:
            logger.warning("ASR stop failed: %s", exc)

    async def set_mic_enabled(self, enabled: bool, flush: bool = False):
        """Toggle the PTT mic state; optionally flush ASR on release."""
        if not self.require_ptt:
            # No-op outside push-to-talk mode.
            return
        self.mic_enabled = enabled
        await self.send_signal({"type": "info", "msg": "mic_on" if enabled else "mic_off"})
        if not enabled and flush:
            self.finalize_asr()

    def _schedule(self, coro):
        """Schedule *coro* on the event loop from any thread (ASR callbacks)."""
        if self.loop:
            self.loop.call_soon_threadsafe(asyncio.create_task, coro)

    def _pick_voice_code(self) -> Optional[str]:
        """Return the cached voice code, loading profile data if needed."""
        if self.voice_code:
            return self.voice_code
        self._prepare_profile()
        return self.voice_code

    async def _interrupt(self):
        """Abort the current turn: cancel generation and drop queued TTS text."""
        self.cancel_event.set()
        # Drain the pending TTS queue.
        while not self.llm_to_tts.empty():
            try:
                self.llm_to_tts.get_nowait()
            except Exception:
                break
        await self.send_signal({"type": "interrupt", "code": "interrupted", "msg": "AI 打断,停止播放"})
        self.is_speaking = False
        self.last_interrupt_time = time.time()

    async def close(self):
        """Tear down DB session, ASR, background tasks, and the socket."""
        if self.db:
            try:
                self.db.close()
            except Exception:
                pass
        if self.recognition:
            try:
                self.recognition.stop()
            except Exception:
                pass
        if self.llm_task:
            self.llm_task.cancel()
        if self.tts_task:
            self.tts_task.cancel()
        if self.tts_stream_task:
            self.tts_stream_task.cancel()
        if self.idle_task:
            self.idle_task.cancel()
        if self.silence_task:
            self.silence_task.cancel()
        if self.websocket.client_state == WebSocketState.CONNECTED:
            await self.websocket.close()

    async def send_signal(self, payload: dict):
        """Send a JSON control message; silently skip if disconnected."""
        if self.websocket.client_state != WebSocketState.CONNECTED:
            return
        try:
            await self.websocket.send_text(json.dumps(payload, ensure_ascii=False))
            self._touch()
        except WebSocketDisconnect:
            return

    def _load_lover(self) -> Optional[Lover]:
        """Load (and cache) the user's lover profile; None on failure."""
        if self.lover is not None:
            return self.lover
        try:
            self.lover = self.db.query(Lover).filter(Lover.user_id == self.user.id).first()
        except Exception as exc:
            logger.warning("Load lover failed: %s", exc)
            self.lover = None
        return self.lover

    def _compose_system_prompt(self) -> str:
        """Build the persona system prompt, appending any custom personality."""
        parts = [
            f"你是用户 {self.user.nickname or '用户'} 的虚拟恋人,请用亲密、温暖、口语化的短句聊天,不要使用 Markdown 符号,不要输出表情、波浪线、星号或动作描述。",
            "回复必须是对话内容,不要包含括号/星号/动作描写/舞台指令,不要用拟声词凑字数,保持简短自然的中文口语句子。",
            "禁止涉政、违法、暴力、未成年相关内容。",
        ]
        lover = self._load_lover()
        if lover and lover.personality_prompt:
            parts.append(f"人格设定:{lover.personality_prompt}")
        return "\n".join(parts)

    @staticmethod
    def _clean_tts_text(text: str) -> str:
        """Strip Markdown/action markup so TTS reads only spoken content."""
        if not text:
            return ""
        # Remove common Markdown/code markers while keeping the text content.
        text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
        text = re.sub(r"`([^`]*)`", r"\1", text)
        text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
        text = re.sub(r"\*[^\*]{0,80}\*", "", text)  # drop *action* fragments
        text = re.sub(r"[~]+", "", text)  # drop tildes
        text = text.replace("*", "")
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    def _prepare_profile(self) -> None:
        """Preload lover + voice so the streaming path never hits the DB.

        Resolution order: lover's explicit voice -> default voice for the
        inferred gender -> first voice for that gender.
        """
        try:
            lover = self._load_lover()
            if lover and lover.voice_id:
                voice = self.db.query(VoiceLibrary).filter(VoiceLibrary.id == lover.voice_id).first()
                if voice and voice.voice_code:
                    self.voice_code = voice.voice_code
                    return
            gender = None
            if lover and lover.gender:
                gender = lover.gender
            if not gender:
                # assumes user.gender == 1 maps to a female lover voice —
                # TODO confirm this mapping against the user model.
                gender = "female" if (self.user.gender or 0) == 1 else "male"
            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender, VoiceLibrary.is_default.is_(True))
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
                return
            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender)
                .order_by(VoiceLibrary.id.asc())
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
        except Exception as exc:
            logger.warning("Prepare profile failed: %s", exc)

    def _touch(self):
        """Refresh the idle-timeout activity timestamp."""
        self.last_activity = time.time()

    async def _idle_watchdog(self):
        """Close the session when no activity for the configured timeout."""
        timeout = settings.VOICE_CALL_IDLE_TIMEOUT or 0
        if timeout <= 0:
            # Watchdog disabled by configuration.
            return
        try:
            while True:
                await asyncio.sleep(5)
                if time.time() - self.last_activity > timeout:
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    async def _silence_watchdog(self):
        """Close the session after prolonged silence (60s without voice).

        ASR itself stays running; it is no longer stopped on short pauses.
        """
        try:
            while True:
                await asyncio.sleep(1.0)
                if time.time() - self.last_voice_activity > 60:
                    logger.info("Long silence, closing session")
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    @staticmethod
    def _peak_pcm16(data: bytes) -> int:
        """Cheaply estimate the peak amplitude of little-endian PCM16 audio."""
        if not data:
            return 0
        view = memoryview(data)
        # One sample per 2 bytes; track the maximum absolute value.
        max_val = 0
        for i in range(0, len(view) - 1, 2):
            sample = int.from_bytes(view[i : i + 2], "little", signed=True)
            if sample < 0:
                sample = -sample
            if sample > max_val:
                max_val = sample
        return max_val
@router.websocket("/call")
async def voice_call(websocket: WebSocket):
    """WebSocket endpoint driving a full voice-call session.

    Binary frames are PCM audio fed to ASR; text frames are control
    signals (mic_on/off, ping, end/stop/flush).  PTT mode is enabled via
    settings or the ``ptt`` query parameter.
    """
    try:
        user = await authenticate_websocket(websocket)
    except HTTPException:
        # Reject before accept: starlette emits the closing handshake itself.
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    ptt_param = (websocket.query_params.get("ptt") or "").strip().lower()
    require_ptt = settings.VOICE_CALL_REQUIRE_PTT or ptt_param in ("1", "true", "yes", "on")
    session = VoiceCallSession(websocket, user, require_ptt=require_ptt)
    try:
        await session.start()
    except HTTPException as exc:
        try:
            # BUG FIX: session.start() accepts the socket before anything can
            # raise, so unconditionally calling accept() here raised a
            # RuntimeError and masked the real error payload.  Only accept if
            # the handshake has not completed yet.
            if websocket.client_state != WebSocketState.CONNECTED:
                await websocket.accept()
            # ensure_ascii=False for consistency with VoiceCallSession.send_signal.
            await websocket.send_text(json.dumps({"type": "error", "msg": exc.detail}, ensure_ascii=False))
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
        except Exception:
            pass
        return
    try:
        while True:
            msg = await websocket.receive()
            # Check for disconnect before touching payload fields.
            if msg.get("type") == "websocket.disconnect":
                break
            if msg.get("bytes") is not None:
                await session.feed_audio(msg["bytes"])
            elif msg.get("text"):
                # Lightweight heartbeat / control signaling.
                text = msg["text"].strip()
                lower_text = text.lower()
                if lower_text in ("mic_on", "ptt_on"):
                    await session.set_mic_enabled(True)
                elif lower_text in ("mic_off", "ptt_off"):
                    await session.set_mic_enabled(False, flush=True)
                elif text == "ping":
                    await websocket.send_text("pong")
                elif text in ("end", "stop", "flush"):
                    session.finalize_asr()
                    await session.send_signal({"type": "info", "msg": "ASR stopped manually"})
                else:
                    await session.send_signal({"type": "info", "msg": "文本消息已忽略"})
    except WebSocketDisconnect:
        pass
    finally:
        await session.close()