import asyncio
import json
import logging
import re
import struct
import time
from typing import List, Optional, Set

import requests
import dashscope
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, status
from fastapi.websockets import WebSocketState

from ..config import settings
from ..deps import AuthedUser, _fetch_user_from_php
from ..llm import chat_completion_stream
from ..tts import synthesize
from ..db import SessionLocal
from ..models import Lover, VoiceLibrary

try:
    from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
except Exception:
    # Fallbacks so the module still imports when dashscope's ASR support is missing;
    # _start_asr() reports the problem when a call is actually started.
    Recognition = None
    RecognitionCallback = object  # type: ignore
    RecognitionResult = object  # type: ignore

try:
    from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer, ResultCallback
except Exception:
    # Fallbacks; _synthesize_stream() raises a clear error when AudioFormat is missing.
    AudioFormat = None  # type: ignore
    SpeechSynthesizer = None  # type: ignore
    ResultCallback = object  # type: ignore

router = APIRouter(prefix="/voice", tags=["voice"])
logger = logging.getLogger("voice_call")
logger.setLevel(logging.INFO)

# Marker pushed onto the LLM -> TTS queue after the last token of one reply.
END_OF_TTS = "<>"

# A TTS chunk is flushed early when a token ends with one of these characters,
# so synthesized speech is cut at natural pauses.
_TTS_PUNCTUATIONS = frozenset(",。?!,.?!;;")


class WSRecognitionCallback(RecognitionCallback):  # type: ignore[misc]
    """ASR callback that pushes sentence-level recognition results into the session.

    All interaction with the session goes through ``session._schedule`` (which uses
    ``call_soon_threadsafe``), because these hooks are presumably invoked from the
    dashscope SDK's own thread rather than the event loop.
    NOTE(review): confirm the threading model against the SDK.
    """

    def __init__(self, session: "VoiceCallSession"):
        super().__init__()
        self.session = session
        # Latest partial (not yet sentence-final) text; flushed as a full sentence
        # when the recognition stream completes, errors or closes.
        self._last_text: Optional[str] = None

    def _flush_last_text(self, reason: str) -> None:
        """Treat any pending partial text as a finished sentence so it is not lost."""
        if self._last_text:
            self.session._schedule(self.session.handle_sentence(self._last_text))
            logger.info("ASR flush last text on %s: %s", reason, self._last_text)
            self._last_text = None

    def on_open(self) -> None:
        logger.info("ASR connection opened")

    def on_complete(self) -> None:
        logger.info("ASR complete")
        # Flush the last partial in case no sentence-end marker arrived.
        self._flush_last_text("complete")

    def on_error(self, result: RecognitionResult) -> None:
        logger.error("ASR error: %s", getattr(result, "message", None) or result)
        self._flush_last_text("error")

    def on_close(self) -> None:
        logger.info("ASR closed")
        self._flush_last_text("close")

    def on_event(self, result: RecognitionResult) -> None:
        """Route final sentences to the LLM pipeline and partial ones to the client."""
        sentence = result.get_sentence()
        if not sentence:
            return
        sentences = sentence if isinstance(sentence, list) else [sentence]
        for sent in sentences:
            if not isinstance(sent, dict):
                continue
            text = sent.get("text")
            if not text:
                continue
            # Accept several spellings of the end-of-sentence flag.
            is_end = (
                bool(sent.get("is_sentence_end"))
                or bool(sent.get("sentence_end"))
                or RecognitionResult.is_sentence_end(sent)
            )
            if is_end:
                self.session._schedule(self.session.handle_sentence(text))
                self._last_text = None
            else:
                self.session._schedule(self.session.send_signal({"type": "partial_asr", "text": text}))
                self._last_text = text
            logger.info("ASR event end=%s sentence=%s", is_end, sent)


async def authenticate_websocket(websocket: WebSocket) -> AuthedUser:
    """Resolve the calling user, reusing the HTTP auth scheme.

    Token sources, in order: ``Authorization: Bearer <token>`` header, ``X-Token``
    header, ``token`` query parameter. Without a token, an integer ``x_user_id``
    query parameter yields a debug user.

    Raises:
        HTTPException(401): no usable credentials, or the user center returned no id.
    """
    headers = websocket.headers
    token = None
    auth_header = headers.get("authorization")
    if auth_header and auth_header.lower().startswith("bearer "):
        token = auth_header.split(" ", 1)[1].strip()
    if not token:
        token = headers.get("x-token")
    if not token:
        # Browsers cannot set headers on websocket handshakes, so allow a query token.
        token = websocket.query_params.get("token")
    x_user_id = websocket.query_params.get("x_user_id")

    if token:
        # NOTE(review): _fetch_user_from_php is called synchronously; if it performs
        # blocking HTTP I/O it stalls the event loop for every connection.
        payload = _fetch_user_from_php(token)
        user_id = payload.get("id") or payload.get("user_id")
        reg_step = payload.get("reg_step") or payload.get("stage") or 1
        gender = payload.get("gender") or 0
        nickname = payload.get("nickname") or payload.get("username") or ""
        if not user_id:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户中心缺少用户ID")
        return AuthedUser(
            id=user_id,
            reg_step=reg_step,
            gender=gender,
            nickname=nickname,
            token=token,
        )

    # NOTE(security): this debug path lets any client act as any user id without a
    # token. It should be disabled outside development.
    if x_user_id is not None:
        try:
            uid = int(x_user_id)
        except Exception:
            uid = None
        if uid is not None:
            return AuthedUser(id=uid, reg_step=2, gender=0, nickname="debug-user", token="")

    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或未授权")


class VoiceCallSession:
    """One real-time voice call: client audio -> ASR -> LLM -> TTS -> client audio.

    Pipeline stages are decoupled by two queues: ``asr_to_llm`` (user sentences)
    and ``llm_to_tts`` (LLM tokens, terminated per reply by ``END_OF_TTS``).
    Turn-taking: while ``is_speaking`` is True, new user sentences are rejected.
    """

    def __init__(self, websocket: WebSocket, user: AuthedUser, require_ptt: bool = False):
        self.websocket = websocket
        self.user = user
        self.require_ptt = require_ptt
        # In push-to-talk mode the mic starts muted until the client sends mic_on.
        self.mic_enabled = not require_ptt
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.asr_to_llm: asyncio.Queue[str] = asyncio.Queue()
        self.llm_to_tts: asyncio.Queue[str] = asyncio.Queue()
        self.is_speaking = False
        self.lover: Optional[Lover] = None
        self.db = SessionLocal()
        self.voice_code: Optional[str] = None
        # history[0] is the system prompt; _trim_history always preserves it.
        self.history: List[dict] = [
            {
                "role": "system",
                "content": self._compose_system_prompt(),
            }
        ]
        self.llm_task: Optional[asyncio.Task] = None
        self.tts_task: Optional[asyncio.Task] = None
        self.tts_stream_task: Optional[asyncio.Task] = None
        self.silence_task: Optional[asyncio.Task] = None
        self.cancel_event = asyncio.Event()
        self.recognition: Optional[Recognition] = None
        self.idle_task: Optional[asyncio.Task] = None
        self.last_activity = time.time()
        self.last_voice_activity = time.time()
        self.has_voice_input = False
        self.last_interrupt_time = 0.0
        # True until the first TTS chunk of a reply is sent (that chunk uses a smaller size threshold).
        self.tts_first_chunk = True
        # Tasks spawned via _schedule; referenced here so they are not garbage-collected mid-run.
        self._background_tasks: Set[asyncio.Task] = set()
        self._closed = False

    async def start(self):
        """Accept the socket, start ASR and the background LLM/TTS/watchdog tasks.

        Raises:
            HTTPException(500): ASR cannot be started (SDK or API key missing).
        """
        await self.websocket.accept()
        self.loop = asyncio.get_running_loop()
        # Preload lover profile and voice so the streaming stages need no DB lookups.
        self._prepare_profile()
        self._start_asr()
        self.llm_task = asyncio.create_task(self._process_llm_loop())
        self.tts_task = asyncio.create_task(self._process_tts_loop())
        self.idle_task = asyncio.create_task(self._idle_watchdog())
        self.silence_task = asyncio.create_task(self._silence_watchdog())
        await self.send_signal({"type": "ready"})
        if self.require_ptt:
            await self.send_signal({"type": "info", "msg": "ptt_enabled"})

    def _start_asr(self):
        """Create and start a new dashscope real-time Recognition bound to this session."""
        if Recognition is None:
            raise HTTPException(status_code=500, detail="未安装 dashscope,无法启动实时 ASR")
        if not settings.DASHSCOPE_API_KEY:
            raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
        dashscope.api_key = settings.DASHSCOPE_API_KEY
        callback = WSRecognitionCallback(self)
        self.recognition = Recognition(
            model=settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
            format="pcm",
            sample_rate=settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
            api_key=settings.DASHSCOPE_API_KEY,
            callback=callback,
        )
        logger.info(
            "ASR started model=%s sample_rate=%s",
            settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
            settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
        )
        self.recognition.start()

    async def handle_sentence(self, text: str):
        """Queue a finished user sentence for the LLM, unless the AI is still speaking."""
        # Turn-based: ignore user speech while the AI talks and ask them to wait.
        if self.is_speaking:
            await self.send_signal({"type": "info", "msg": "请等待 AI 说完再讲话"})
            return
        logger.info("Handle sentence: %s", text)
        await self.asr_to_llm.put(text)

    async def _process_llm_loop(self):
        """Consume user sentences forever and stream an LLM reply for each."""
        while True:
            text = await self.asr_to_llm.get()
            self.cancel_event.clear()
            try:
                await self._stream_llm(text)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.exception("LLM error", exc_info=exc)
                await self.send_signal({"type": "error", "msg": "LLM 生成失败"})
                self.is_speaking = False

    def _trim_history(self) -> None:
        """Cap history at VOICE_CALL_MAX_HISTORY messages, always keeping the system prompt."""
        limit = settings.VOICE_CALL_MAX_HISTORY
        if len(self.history) <= limit:
            return
        has_system = bool(self.history) and self.history[0].get("role") == "system"
        head = self.history[:1] if has_system else []
        keep = max(limit - len(head), 1)
        self.history = head + self.history[-keep:]

    async def _stream_llm(self, text: str):
        """Stream an LLM reply to *text* into the TTS queue and record it in history.

        ``chat_completion_stream`` is a synchronous iterator, so both creating it and
        pulling each chunk run in the default executor, keeping the event loop free
        for audio intake and signalling while the model generates.
        """
        self.history.append({"role": "user", "content": text})
        self._trim_history()

        loop = asyncio.get_running_loop()
        stream = await loop.run_in_executor(None, chat_completion_stream, list(self.history))
        iterator = iter(stream)
        exhausted = object()
        self.is_speaking = True
        self.tts_first_chunk = True
        buffer: List[str] = []
        while True:
            chunk = await loop.run_in_executor(None, next, iterator, exhausted)
            if chunk is exhausted or self.cancel_event.is_set():
                break
            buffer.append(chunk)
            await self.llm_to_tts.put(chunk)

        if self.cancel_event.is_set():
            self.is_speaking = False
            return

        await self.llm_to_tts.put(END_OF_TTS)
        full_reply = "".join(buffer)
        self.history.append({"role": "assistant", "content": full_reply})
        if full_reply:
            # Send the full text downstream for display/debugging on the client.
            await self.send_signal({"type": "reply_text", "text": full_reply})

    async def _speak(self, text_chunk: str) -> bool:
        """Synthesize one text chunk and stream its audio to the client.

        Returns False if synthesis failed (an error signal has been sent and
        ``is_speaking`` reset), True otherwise (including when nothing was left to
        say after cleaning). ``WebSocketDisconnect`` propagates to the caller.
        """
        clean_text = self._clean_tts_text(text_chunk)
        if not clean_text:
            return True
        try:
            async for audio in self._synthesize_stream(clean_text):
                if self.cancel_event.is_set():
                    break
                await self.websocket.send_bytes(audio)
                self._touch()
        except WebSocketDisconnect:
            raise
        except Exception as exc:
            logger.exception("TTS error", exc_info=exc)
            await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
            self.is_speaking = False
            return False
        return True

    async def _process_tts_loop(self):
        """Group LLM tokens into speakable chunks and synthesize them until disconnect.

        A chunk is flushed when a token ends with punctuation or the buffered text
        reaches 8 characters (first chunk of a reply, for low latency) or 18 (later).
        """
        pending: List[str] = []
        pending_len = 0
        try:
            while True:
                token = await self.llm_to_tts.get()
                if self.cancel_event.is_set():
                    # Reply was cancelled: drop everything buffered for it.
                    pending, pending_len = [], 0
                    self.tts_first_chunk = True
                    continue

                if token == END_OF_TTS:
                    # Speak whatever remains of this reply, then close the turn.
                    if pending:
                        text_chunk = "".join(pending)
                        pending, pending_len = [], 0
                        if not await self._speak(text_chunk):
                            self.tts_first_chunk = True
                            continue
                    self.tts_first_chunk = True
                    self.is_speaking = False
                    await self.send_signal({"type": "reply_end"})
                    continue

                pending.append(token)
                pending_len += len(token)
                threshold = 8 if self.tts_first_chunk else 18
                ends_with_punct = bool(token) and token[-1] in _TTS_PUNCTUATIONS
                if ends_with_punct or pending_len >= threshold:
                    text_chunk = "".join(pending)
                    pending, pending_len = [], 0
                    self.tts_first_chunk = False
                    await self._speak(text_chunk)
        except WebSocketDisconnect:
            return

    async def _synthesize_stream(self, text: str):
        """Synthesize *text* and yield the resulting audio bytes.

        This performs a single (non-streaming) ``synthesize`` call in the default
        executor and yields one chunk; empty results yield nothing.

        Raises:
            RuntimeError: dashscope tts_v2 (``AudioFormat``) is not importable.
        """
        if AudioFormat is None:
            raise RuntimeError("dashscope tts_v2 is unavailable: AudioFormat could not be imported")
        model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
        voice = self._pick_voice_code() or settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
        fmt = settings.VOICE_CALL_TTS_FORMAT.lower() if settings.VOICE_CALL_TTS_FORMAT else "mp3"
        audio_format = AudioFormat.MP3_22050HZ_MONO_256KBPS if fmt == "mp3" else AudioFormat.PCM_16000HZ_MONO
        loop = asyncio.get_running_loop()
        audio_bytes, _fmt_name = await loop.run_in_executor(
            None,
            lambda: synthesize(text, model=model, voice=voice, audio_format=audio_format),  # type: ignore[arg-type]
        )
        if audio_bytes:
            yield audio_bytes

    async def feed_audio(self, data: bytes):
        """Forward one client audio frame to ASR and update activity timestamps."""
        if self.require_ptt and not self.mic_enabled:
            # PTT mode with the button released: drop the audio.
            self._touch()
            return
        # Lazily (re)start ASR, e.g. after finalize_asr() stopped it.
        # NOTE(review): relies on the SDK's private `_running` attribute.
        if not (self.recognition and getattr(self.recognition, "_running", False)):
            try:
                self._start_asr()
            except Exception as exc:
                logger.error("ASR restart failed: %s", exc)
                return
        if self.recognition:
            self.recognition.send_audio_frame(data)
        logger.debug("recv audio chunk bytes=%s", len(data))
        # Amplitude above 300 counts as voice; used only for activity tracking, never to interrupt.
        if self._peak_pcm16(data) > 300:
            self.last_voice_activity = time.time()
            self.has_voice_input = True
        self._touch()

    def finalize_asr(self):
        """Stop ASR so the service emits its final result (best-effort)."""
        try:
            if self.recognition:
                self.recognition.stop()
                logger.info("ASR stop requested manually")
        except Exception as exc:
            logger.warning("ASR stop failed: %s", exc)

    async def set_mic_enabled(self, enabled: bool, flush: bool = False):
        """Toggle the PTT microphone; with *flush* on release, finalize the ASR sentence.

        No-op when push-to-talk is not required.
        """
        if not self.require_ptt:
            return
        self.mic_enabled = enabled
        await self.send_signal({"type": "info", "msg": "mic_on" if enabled else "mic_off"})
        if not enabled and flush:
            self.finalize_asr()

    def _schedule(self, coro):
        """Run *coro* on the session's event loop; safe to call from any thread.

        The created task is kept in ``_background_tasks`` until it finishes so it
        cannot be garbage-collected mid-run. If there is no usable loop, or the
        session is closed, the coroutine is closed instead of being left un-awaited.
        """
        loop = self.loop
        if loop is None or self._closed or loop.is_closed():
            coro.close()
            return

        def _spawn() -> None:
            task = loop.create_task(coro)
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

        try:
            loop.call_soon_threadsafe(_spawn)
        except RuntimeError:
            # The loop was closed between the check above and scheduling.
            coro.close()

    def _pick_voice_code(self) -> Optional[str]:
        """Return the voice_code from the lover config or a default voice (cached)."""
        if self.voice_code:
            return self.voice_code
        self._prepare_profile()
        return self.voice_code

    async def _interrupt(self):
        """Cancel the current reply: drop queued TTS tokens and notify the client."""
        self.cancel_event.set()
        while not self.llm_to_tts.empty():
            try:
                self.llm_to_tts.get_nowait()
            except Exception:
                break
        await self.send_signal({"type": "interrupt", "code": "interrupted", "msg": "AI 打断,停止播放"})
        self.is_speaking = False
        self.last_interrupt_time = time.time()

    def _ws_open(self) -> bool:
        """True while both the client and our side of the websocket are connected."""
        return (
            self.websocket.client_state == WebSocketState.CONNECTED
            and self.websocket.application_state == WebSocketState.CONNECTED
        )

    async def close(self):
        """Release all session resources and close the websocket; idempotent.

        This may run inside one of the managed tasks (the idle/silence watchdogs
        call it), so the current task is never cancelled here. Cancelling it would
        abort the websocket close below.
        """
        if self._closed:
            return
        self._closed = True
        if self.db:
            try:
                self.db.close()
            except Exception as exc:
                logger.warning("DB session close failed: %s", exc)
        if self.recognition:
            try:
                self.recognition.stop()
            except Exception:
                # Best-effort: recognition may already be stopped (e.g. via finalize_asr).
                pass
        current = asyncio.current_task()
        for task in (self.llm_task, self.tts_task, self.tts_stream_task, self.idle_task, self.silence_task):
            if task is not None and task is not current:
                task.cancel()
        if self._ws_open():
            try:
                await self.websocket.close()
            except (WebSocketDisconnect, RuntimeError) as exc:
                logger.debug("WebSocket close failed: %s", exc)

    async def send_signal(self, payload: dict):
        """Send *payload* as a JSON text frame; silently skipped once the socket is closed."""
        if not self._ws_open():
            return
        try:
            await self.websocket.send_text(json.dumps(payload, ensure_ascii=False))
        except (WebSocketDisconnect, RuntimeError):
            return
        self._touch()

    def _load_lover(self) -> Optional[Lover]:
        """Load (once) the Lover row for the current user; None if absent or on DB error."""
        if self.lover is not None:
            return self.lover
        try:
            self.lover = self.db.query(Lover).filter(Lover.user_id == self.user.id).first()
        except Exception as exc:
            logger.warning("Load lover failed: %s", exc)
            self.lover = None
        return self.lover

    def _compose_system_prompt(self) -> str:
        """Build the system prompt: base persona and safety rules plus the lover's personality."""
        parts = [
            f"你是用户 {self.user.nickname or '用户'} 的虚拟恋人,请用亲密、温暖、口语化的短句聊天,不要使用 Markdown 符号,不要输出表情、波浪线、星号或动作描述。",
            "回复必须是对话内容,不要包含括号/星号/动作描写/舞台指令,不要用拟声词凑字数,保持简短自然的中文口语句子。",
            "禁止涉政、违法、暴力、未成年相关内容。",
        ]
        lover = self._load_lover()
        if lover and lover.personality_prompt:
            parts.append(f"人格设定:{lover.personality_prompt}")
        return "\n".join(parts)

    @staticmethod
    def _clean_tts_text(text: str) -> str:
        """Strip Markdown/code markup, *action* snippets and tildes so only speakable text remains."""
        if not text:
            return ""
        text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)  # **bold** -> bold
        text = re.sub(r"`([^`]*)`", r"\1", text)  # `code` -> code
        text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)  # [label](url) -> label
        text = re.sub(r"\*[^\*]{0,80}\*", "", text)  # drop *action* snippets
        text = re.sub(r"[~~]+", "", text)  # drop tildes
        text = text.replace("*", "")
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    def _prepare_profile(self) -> None:
        """Resolve the TTS voice_code.

        Resolution order: the lover's own voice; otherwise the default voice for the
        gender; otherwise the lowest-id voice of that gender. Errors are logged only.
        """
        try:
            lover = self._load_lover()
            if lover and lover.voice_id:
                voice = self.db.query(VoiceLibrary).filter(VoiceLibrary.id == lover.voice_id).first()
                if voice and voice.voice_code:
                    self.voice_code = voice.voice_code
                    return
            gender = lover.gender if lover and lover.gender else None
            if not gender:
                gender = "female" if (self.user.gender or 0) == 1 else "male"
            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender, VoiceLibrary.is_default.is_(True))
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
                return
            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender)
                .order_by(VoiceLibrary.id.asc())
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
        except Exception as exc:
            logger.warning("Prepare profile failed: %s", exc)

    def _touch(self):
        """Record activity for the idle watchdog."""
        self.last_activity = time.time()

    async def _idle_watchdog(self):
        """Close the session after VOICE_CALL_IDLE_TIMEOUT seconds without activity (disabled if <= 0)."""
        timeout = settings.VOICE_CALL_IDLE_TIMEOUT or 0
        if timeout <= 0:
            return
        try:
            while True:
                await asyncio.sleep(5)
                if time.time() - self.last_activity > timeout:
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    async def _silence_watchdog(self):
        """Close the session after 60 s without detected voice; ASR itself stays running through short pauses."""
        try:
            while True:
                await asyncio.sleep(1.0)
                if time.time() - self.last_voice_activity > 60:
                    logger.info("Long silence, closing session")
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    @staticmethod
    def _peak_pcm16(data: bytes) -> int:
        """Return the peak absolute sample value of little-endian 16-bit PCM *data*.

        A trailing odd byte is ignored; empty input returns 0.
        """
        count = len(data) // 2
        if count == 0:
            return 0
        samples = struct.unpack(f"<{count}h", memoryview(data)[: count * 2])
        return max(map(abs, samples))


@router.websocket("/call")
async def voice_call(websocket: WebSocket):
    """Websocket endpoint for a real-time voice call.

    Binary frames are PCM audio for ASR. Text frames are control commands:
    mic_on/ptt_on, mic_off/ptt_off, ping, and end/stop/flush. Push-to-talk is
    enabled by settings.VOICE_CALL_REQUIRE_PTT or the ``ptt`` query parameter.
    """
    try:
        user = await authenticate_websocket(websocket)
    except HTTPException:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    ptt_param = (websocket.query_params.get("ptt") or "").strip().lower()
    require_ptt = settings.VOICE_CALL_REQUIRE_PTT or ptt_param in ("1", "true", "yes", "on")
    session = VoiceCallSession(websocket, user, require_ptt=require_ptt)

    try:
        await session.start()
    except HTTPException as exc:
        # start() normally accepts the socket before anything that raises
        # HTTPException, so only accept here if that did not happen.
        try:
            if websocket.application_state == WebSocketState.CONNECTING:
                await websocket.accept()
            await websocket.send_text(json.dumps({"type": "error", "msg": exc.detail}))
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
        except Exception as report_exc:
            logger.warning("Failed to report session start error: %s", report_exc)
        await session.close()
        return
    except Exception:
        # Release the DB session and anything already started before propagating.
        await session.close()
        raise

    try:
        while True:
            msg = await websocket.receive()
            if msg.get("type") == "websocket.disconnect":
                break
            if msg.get("bytes") is not None:
                await session.feed_audio(msg["bytes"])
            elif msg.get("text"):
                # Simple heartbeat / signalling commands.
                text = msg["text"].strip()
                lower_text = text.lower()
                if lower_text in ("mic_on", "ptt_on"):
                    await session.set_mic_enabled(True)
                elif lower_text in ("mic_off", "ptt_off"):
                    await session.set_mic_enabled(False, flush=True)
                elif text == "ping":
                    await websocket.send_text("pong")
                elif text in ("end", "stop", "flush"):
                    session.finalize_asr()
                    await session.send_signal({"type": "info", "msg": "ASR stopped manually"})
                else:
                    await session.send_signal({"type": "info", "msg": "文本消息已忽略"})
    except WebSocketDisconnect:
        pass
    finally:
        await session.close()