619 lines
24 KiB
Python
619 lines
24 KiB
Python
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
import time
|
||
from typing import List, Optional
|
||
|
||
import requests
|
||
import dashscope
|
||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect, status
|
||
from fastapi.websockets import WebSocketState
|
||
|
||
from ..config import settings
|
||
from ..deps import AuthedUser, get_current_user, _fetch_user_from_php
|
||
from ..llm import chat_completion_stream
|
||
from ..tts import synthesize
|
||
from ..db import SessionLocal
|
||
from ..models import Lover, VoiceLibrary, User
|
||
|
||
try:
|
||
from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
|
||
except Exception: # dashscope 未安装时提供兜底
|
||
Recognition = None
|
||
RecognitionCallback = object # type: ignore
|
||
RecognitionResult = object # type: ignore
|
||
try:
|
||
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer, ResultCallback
|
||
except Exception:
|
||
AudioFormat = None # type: ignore
|
||
SpeechSynthesizer = None # type: ignore
|
||
ResultCallback = object # type: ignore
|
||
|
||
|
||
# Router exposing the HTTP + WebSocket endpoints for realtime voice calls.
router = APIRouter(prefix="/voice", tags=["voice"])

# Dedicated logger for tracing the voice-call pipeline (ASR / LLM / TTS).
logger = logging.getLogger("voice_call")
logger.setLevel(logging.INFO)

# Sentinel token pushed into the LLM->TTS queue to mark the end of one reply.
END_OF_TTS = "<<VOICE_CALL_TTS_END>>"
|
||
|
||
|
||
@router.get("/call/duration")
async def get_call_duration(user: AuthedUser = Depends(get_current_user)):
    """Return the voice-call duration quota for the current user.

    VIP users (``vip_endtime`` still in the future) get unlimited call time,
    signalled as ``duration == 0``; regular users are capped at 5 minutes
    (300000 ms).

    Raises:
        HTTPException: 404 when the user row does not exist.
    """
    from datetime import datetime

    from ..response import success_response

    db = SessionLocal()
    try:
        user_row = db.query(User).filter(User.id == user.id).first()
        if not user_row:
            raise HTTPException(status_code=404, detail="用户不存在")

        # vip_endtime is a Unix timestamp; VIP is active while it lies in the
        # future. bool(...) so the payload carries true/false, never null/0
        # (bare `a and b` would leak the falsy operand into the response).
        current_timestamp = int(datetime.utcnow().timestamp())
        is_vip = bool(user_row.vip_endtime and user_row.vip_endtime > current_timestamp)

        # 0 means "no limit"; non-VIP users get 5 minutes, in milliseconds.
        duration = 0 if is_vip else 300000

        return success_response({
            "duration": duration,
            "is_vip": is_vip,
        })
    finally:
        db.close()
|
||
|
||
|
||
class WSRecognitionCallback(RecognitionCallback):  # type: ignore[misc]
    """ASR callback that pushes sentence-level results into the session queues."""

    def __init__(self, session: "VoiceCallSession"):
        super().__init__()
        self.session = session
        # Latest partial (not yet sentence-final) transcript; flushed on
        # complete/error/close so trailing speech is not lost.
        self._last_text: Optional[str] = None

    def _flush_last(self, reason: str) -> None:
        """Emit any pending partial text as a finished sentence.

        Shared by on_complete/on_error/on_close, which previously repeated
        this logic verbatim; `reason` only affects the log line.
        """
        if self._last_text:
            self.session._schedule(self.session.handle_sentence(self._last_text))
            logger.info("ASR flush last text on %s: %s", reason, self._last_text)
            self._last_text = None

    def on_open(self) -> None:
        logger.info("ASR connection opened")

    def on_complete(self) -> None:
        logger.info("ASR complete")
        # Treat the trailing partial as a full sentence in case no explicit
        # sentence-end marker ever arrived.
        self._flush_last("complete")

    def on_error(self, result: RecognitionResult) -> None:
        logger.error("ASR error: %s", getattr(result, "message", None) or result)
        self._flush_last("error")

    def on_close(self) -> None:
        logger.info("ASR closed")
        self._flush_last("close")

    def on_event(self, result: RecognitionResult) -> None:
        """Handle an incremental recognition event (partial or final sentence)."""
        sentence = result.get_sentence()
        if not sentence:
            return

        # DashScope may deliver a single sentence dict or a list of them.
        sentences = sentence if isinstance(sentence, list) else [sentence]
        for sent in sentences:
            text = sent.get("text") if isinstance(sent, dict) else None
            if not text:
                continue
            is_end = False
            if isinstance(sent, dict):
                # Accept any of the known sentence-end markers.
                is_end = (
                    bool(sent.get("is_sentence_end"))
                    or bool(sent.get("sentence_end"))
                    or RecognitionResult.is_sentence_end(sent)
                )
            if is_end:
                self.session._schedule(self.session.handle_sentence(text))
                self._last_text = None
            else:
                self.session._schedule(self.session.send_signal({"type": "partial_asr", "text": text}))
                self._last_text = text
            logger.info("ASR event end=%s sentence=%s", is_end, sent)
|
||
|
||
|
||
async def authenticate_websocket(websocket: WebSocket) -> AuthedUser:
    """Reuse the HTTP auth logic: Authorization / X-Token / x_user_id (debug)."""
    headers = websocket.headers

    # Token resolution order: Bearer header, then X-Token header, then the
    # `token` query parameter.
    token = None
    auth_header = headers.get("authorization")
    if auth_header and auth_header.lower().startswith("bearer "):
        token = auth_header.split(" ", 1)[1].strip()
    token = token or headers.get("x-token") or websocket.query_params.get("token")

    x_user_id = websocket.query_params.get("x_user_id")

    if token:
        payload = _fetch_user_from_php(token)
        resolved_id = payload.get("id") or payload.get("user_id")
        if not resolved_id:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户中心缺少用户ID")
        return AuthedUser(
            id=resolved_id,
            reg_step=payload.get("reg_step") or payload.get("stage") or 1,
            gender=payload.get("gender") or 0,
            nickname=payload.get("nickname") or payload.get("username") or "",
            token=token,
        )

    # Debug fallback: an explicit numeric x_user_id query parameter.
    if x_user_id is not None:
        try:
            debug_uid = int(x_user_id)
        except Exception:
            debug_uid = None
        if debug_uid is not None:
            return AuthedUser(id=debug_uid, reg_step=2, gender=0, nickname="debug-user", token="")

    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或未授权")
|
||
|
||
|
||
class VoiceCallSession:
    """State machine for a single realtime voice-call WebSocket connection.

    Pipeline: client PCM audio -> DashScope realtime ASR -> streaming LLM
    -> chunked TTS synthesis -> binary audio frames back to the client,
    with JSON control signals interleaved on the same socket.
    """

    def __init__(self, websocket: WebSocket, user: AuthedUser, require_ptt: bool = False):
        self.websocket = websocket
        self.user = user
        # Push-to-talk mode: when enabled, incoming audio is dropped unless
        # the client has sent mic_on/ptt_on.
        self.require_ptt = require_ptt
        self.mic_enabled = not require_ptt
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        # Finished ASR sentences awaiting the LLM, and LLM tokens awaiting TTS.
        self.asr_to_llm: asyncio.Queue[str] = asyncio.Queue()
        self.llm_to_tts: asyncio.Queue[str] = asyncio.Queue()
        # True while a reply is being generated/played; user speech is ignored
        # during this window (turn-based conversation).
        self.is_speaking = False
        self.lover: Optional[Lover] = None
        self.db = SessionLocal()
        self.voice_code: Optional[str] = None
        # Chat history seeded with the persona system prompt.
        self.history: List[dict] = [
            {
                "role": "system",
                "content": self._compose_system_prompt(),
            }
        ]
        self.llm_task: Optional[asyncio.Task] = None
        self.tts_task: Optional[asyncio.Task] = None
        self.tts_stream_task: Optional[asyncio.Task] = None
        self.silence_task: Optional[asyncio.Task] = None
        # Set to abort the current LLM/TTS turn (interruption).
        self.cancel_event = asyncio.Event()
        self.recognition: Optional[Recognition] = None
        self.idle_task: Optional[asyncio.Task] = None
        # Activity timestamps driving the idle/silence watchdogs.
        self.last_activity = time.time()
        self.last_voice_activity = time.time()
        self.has_voice_input = False
        self.last_interrupt_time = 0.0
        # First chunk of a reply uses a smaller buffering threshold so the
        # first audio arrives with lower latency.
        self.tts_first_chunk = True

    async def start(self):
        """Accept the socket and launch the ASR + background pipeline tasks."""
        await self.websocket.accept()
        self.loop = asyncio.get_running_loop()
        # Preload lover profile and voice so streaming stages don't block the event loop.
        self._prepare_profile()
        # Start ASR.
        self._start_asr()
        # Start the LLM/TTS background tasks and watchdogs.
        self.llm_task = asyncio.create_task(self._process_llm_loop())
        self.tts_task = asyncio.create_task(self._process_tts_loop())
        self.idle_task = asyncio.create_task(self._idle_watchdog())
        self.silence_task = asyncio.create_task(self._silence_watchdog())
        await self.send_signal({"type": "ready"})
        if self.require_ptt:
            await self.send_signal({"type": "info", "msg": "ptt_enabled"})

    def _start_asr(self):
        """Create and start the DashScope realtime recognizer.

        Raises:
            HTTPException: 500 when dashscope is not installed or the API key
                is not configured.
        """
        if Recognition is None:
            raise HTTPException(status_code=500, detail="未安装 dashscope,无法启动实时 ASR")
        if not settings.DASHSCOPE_API_KEY:
            raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
        dashscope.api_key = settings.DASHSCOPE_API_KEY
        callback = WSRecognitionCallback(self)
        self.recognition = Recognition(
            model=settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
            format="pcm",
            sample_rate=settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
            api_key=settings.DASHSCOPE_API_KEY,
            callback=callback,
        )
        logger.info(
            "ASR started model=%s sample_rate=%s",
            settings.VOICE_CALL_ASR_MODEL or "paraformer-realtime-v2",
            settings.VOICE_CALL_ASR_SAMPLE_RATE or 16000,
        )
        self.recognition.start()

    async def handle_sentence(self, text: str):
        """Queue a finished user sentence for the LLM, unless the AI is talking."""
        # Turn-based: ignore user speech while the AI is speaking and ask
        # the user to wait.
        if self.is_speaking:
            await self.send_signal({"type": "info", "msg": "请等待 AI 说完再讲话"})
            return
        logger.info("Handle sentence: %s", text)
        await self.asr_to_llm.put(text)

    async def _process_llm_loop(self):
        """Background task: consume ASR sentences and stream LLM replies."""
        while True:
            text = await self.asr_to_llm.get()
            # Reset any pending interruption before starting a new turn.
            self.cancel_event.clear()
            try:
                await self._stream_llm(text)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.exception("LLM error", exc_info=exc)
                await self.send_signal({"type": "error", "msg": "LLM 生成失败"})
                self.is_speaking = False

    async def _stream_llm(self, text: str):
        """Run one LLM turn: stream tokens into the TTS queue and record history."""
        self.history.append({"role": "user", "content": text})
        # Cap the history length.
        if len(self.history) > settings.VOICE_CALL_MAX_HISTORY:
            self.history = self.history[-settings.VOICE_CALL_MAX_HISTORY :]
        stream = chat_completion_stream(self.history)
        self.is_speaking = True
        self.tts_first_chunk = True
        buffer = []
        for chunk in stream:
            if self.cancel_event.is_set():
                break
            buffer.append(chunk)
            await self.llm_to_tts.put(chunk)
        # Only signal end-of-reply to TTS if the turn was not interrupted.
        if not self.cancel_event.is_set():
            await self.llm_to_tts.put(END_OF_TTS)
        full_reply = "".join(buffer)
        self.history.append({"role": "assistant", "content": full_reply})
        if full_reply:
            # Send the full reply text downstream for display/debugging.
            await self.send_signal({"type": "reply_text", "text": full_reply})
        else:
            self.is_speaking = False

    async def _process_tts_loop(self):
        """Background task: buffer LLM tokens into phrases and synthesize audio.

        Tokens are accumulated until a punctuation mark or a length threshold,
        then cleaned and synthesized; END_OF_TTS flushes the remainder and
        signals reply_end.
        """
        temp_buffer = []
        # Phrase boundaries: both full-width and ASCII punctuation.
        punctuations = set(",。?!,.?!;;")
        while True:
            token = await self.llm_to_tts.get()
            if self.cancel_event.is_set():
                # Interrupted: drop buffered text and reset first-chunk latency mode.
                temp_buffer = []
                self.tts_first_chunk = True
                continue

            if token == END_OF_TTS:
                # Flush whatever remains in the buffer.
                if temp_buffer:
                    text_chunk = "".join(temp_buffer)
                    temp_buffer = []
                    clean_text = self._clean_tts_text(text_chunk)
                    if clean_text:
                        try:
                            async for chunk in self._synthesize_stream(clean_text):
                                if self.cancel_event.is_set():
                                    break
                                await self.websocket.send_bytes(chunk)
                                self._touch()
                        except WebSocketDisconnect:
                            break
                        except Exception as exc:
                            logger.exception("TTS error", exc_info=exc)
                            await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
                            self.is_speaking = False
                            continue
                self.tts_first_chunk = True
                self.is_speaking = False
                await self.send_signal({"type": "reply_end"})
                continue

            temp_buffer.append(token)
            last_char = token[-1] if token else ""
            # Shorter threshold for the first audible chunk to cut latency.
            threshold = 8 if self.tts_first_chunk else 18
            if last_char in punctuations or len("".join(temp_buffer)) >= threshold:
                text_chunk = "".join(temp_buffer)
                temp_buffer = []
                self.tts_first_chunk = False
                clean_text = self._clean_tts_text(text_chunk)
                if not clean_text:
                    continue
                try:
                    async for chunk in self._synthesize_stream(clean_text):
                        if self.cancel_event.is_set():
                            break
                        await self.websocket.send_bytes(chunk)
                        self._touch()
                except WebSocketDisconnect:
                    break
                except Exception as exc:
                    logger.exception("TTS error", exc_info=exc)
                    await self.send_signal({"type": "error", "code": "tts_failed", "msg": "TTS 合成失败"})
        self.is_speaking = False
        # Unreachable, kept in case the loop logic changes.

    async def _synthesize_stream(self, text: str):
        """
        Call cosyvoice v2 synthesis and yield audio chunk by chunk.
        Falls back to one-shot synthesis when streaming is unavailable.
        """
        model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
        voice = self._pick_voice_code() or settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
        fmt = settings.VOICE_CALL_TTS_FORMAT.lower() if settings.VOICE_CALL_TTS_FORMAT else "mp3"
        audio_format = AudioFormat.MP3_22050HZ_MONO_256KBPS if fmt == "mp3" else AudioFormat.PCM_16000HZ_MONO

        # Synchronous one-shot synthesis to avoid streaming-related blocking.
        # NOTE(review): despite the name, this currently yields exactly one
        # chunk from a single blocking synthesize() call.
        audio_bytes, _fmt_name = synthesize(text, model=model, voice=voice, audio_format=audio_format)  # type: ignore[arg-type]
        yield audio_bytes

    async def feed_audio(self, data: bytes):
        """Forward one client audio frame to ASR and update activity tracking."""
        if self.require_ptt and not self.mic_enabled:
            # PTT mode: drop audio while the talk button is not held.
            self._touch()
            return
        # Lazily restart ASR if it was stopped earlier.
        if not (self.recognition and getattr(self.recognition, "_running", False)):
            try:
                self._start_asr()
            except Exception as exc:
                logger.error("ASR restart failed: %s", exc)
                return
        if self.recognition:
            self.recognition.send_audio_frame(data)
            logger.debug("recv audio chunk bytes=%s", len(data))
        peak = self._peak_pcm16(data)
        now = time.time()
        if peak > 300:  # used only for activity detection; no longer triggers barge-in
            self.last_voice_activity = now
            self.has_voice_input = True
        self._touch()

    def finalize_asr(self):
        """Actively stop ASR to force it to return final results."""
        try:
            if self.recognition:
                self.recognition.stop()
                logger.info("ASR stop requested manually")
        except Exception as exc:
            logger.warning("ASR stop failed: %s", exc)

    async def set_mic_enabled(self, enabled: bool, flush: bool = False):
        """Toggle the PTT mic state; optionally flush ASR when turning off."""
        if not self.require_ptt:
            return
        self.mic_enabled = enabled
        await self.send_signal({"type": "info", "msg": "mic_on" if enabled else "mic_off"})
        if not enabled and flush:
            self.finalize_asr()

    def _schedule(self, coro):
        """Schedule a coroutine on the session loop from a non-loop thread.

        Used by the ASR callback, which runs on a dashscope worker thread.
        """
        if self.loop:
            self.loop.call_soon_threadsafe(asyncio.create_task, coro)

    def _pick_voice_code(self) -> Optional[str]:
        """Pick a voice_code from the lover config or fall back to defaults."""
        if self.voice_code:
            return self.voice_code
        self._prepare_profile()
        return self.voice_code

    async def _interrupt(self):
        """Abort the current reply: cancel generation and drain pending TTS text."""
        self.cancel_event.set()
        # Drain the queue.
        while not self.llm_to_tts.empty():
            try:
                self.llm_to_tts.get_nowait()
            except Exception:
                break
        await self.send_signal({"type": "interrupt", "code": "interrupted", "msg": "AI 打断,停止播放"})
        self.is_speaking = False
        self.last_interrupt_time = time.time()

    async def close(self):
        """Tear down DB session, ASR, background tasks, and the socket."""
        if self.db:
            try:
                self.db.close()
            except Exception:
                pass
        if self.recognition:
            try:
                self.recognition.stop()
            except Exception:
                pass
        if self.llm_task:
            self.llm_task.cancel()
        if self.tts_task:
            self.tts_task.cancel()
        if self.tts_stream_task:
            self.tts_stream_task.cancel()
        if self.idle_task:
            self.idle_task.cancel()
        if self.silence_task:
            self.silence_task.cancel()
        if self.websocket.client_state == WebSocketState.CONNECTED:
            await self.websocket.close()

    async def send_signal(self, payload: dict):
        """Send a JSON control message to the client; no-op when disconnected."""
        if self.websocket.client_state != WebSocketState.CONNECTED:
            return
        try:
            await self.websocket.send_text(json.dumps(payload, ensure_ascii=False))
            self._touch()
        except WebSocketDisconnect:
            return

    def _load_lover(self) -> Optional[Lover]:
        """Load (and cache) the user's Lover row; None on failure."""
        if self.lover is not None:
            return self.lover
        try:
            self.lover = self.db.query(Lover).filter(Lover.user_id == self.user.id).first()
        except Exception as exc:
            logger.warning("Load lover failed: %s", exc)
            self.lover = None
        return self.lover

    def _compose_system_prompt(self) -> str:
        """Build the system prompt: base persona rules plus lover personality."""
        parts = [
            f"你是用户 {self.user.nickname or '用户'} 的虚拟恋人,请用亲密、温暖、口语化的短句聊天,不要使用 Markdown 符号,不要输出表情、波浪线、星号或动作描述。",
            "回复必须是对话内容,不要包含括号/星号/动作描写/舞台指令,不要用拟声词凑字数,保持简短自然的中文口语句子。",
            "禁止涉政、违法、暴力、未成年相关内容。",
        ]
        lover = self._load_lover()
        if lover and lover.personality_prompt:
            parts.append(f"人格设定:{lover.personality_prompt}")
        return "\n".join(parts)

    @staticmethod
    def _clean_tts_text(text: str) -> str:
        """Strip Markdown/action markup so TTS reads only spoken content."""
        if not text:
            return ""
        # Remove common Markdown/code markers while keeping the text content.
        text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
        text = re.sub(r"`([^`]*)`", r"\1", text)
        text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
        text = re.sub(r"\*[^\*]{0,80}\*", "", text)  # drop *action* fragments
        text = re.sub(r"[~~]+", "", text)  # drop tildes
        text = text.replace("*", "")
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    def _prepare_profile(self) -> None:
        """Preload lover and voice so the streaming stage doesn't block the event loop.

        Resolution order: the lover's configured voice, then the default voice
        for the inferred gender, then the first voice for that gender.
        """
        try:
            lover = self._load_lover()
            if lover and lover.voice_id:
                voice = self.db.query(VoiceLibrary).filter(VoiceLibrary.id == lover.voice_id).first()
                if voice and voice.voice_code:
                    self.voice_code = voice.voice_code
                    return

            # Infer the voice gender: lover setting, else opposite of the user.
            gender = None
            if lover and lover.gender:
                gender = lover.gender
            if not gender:
                gender = "female" if (self.user.gender or 0) == 1 else "male"

            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender, VoiceLibrary.is_default.is_(True))
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
                return
            voice = (
                self.db.query(VoiceLibrary)
                .filter(VoiceLibrary.gender == gender)
                .order_by(VoiceLibrary.id.asc())
                .first()
            )
            if voice and voice.voice_code:
                self.voice_code = voice.voice_code
        except Exception as exc:
            logger.warning("Prepare profile failed: %s", exc)

    def _touch(self):
        """Record activity for the idle watchdog."""
        self.last_activity = time.time()

    async def _idle_watchdog(self):
        """Close the session after VOICE_CALL_IDLE_TIMEOUT seconds of no activity."""
        timeout = settings.VOICE_CALL_IDLE_TIMEOUT or 0
        if timeout <= 0:
            return
        try:
            while True:
                await asyncio.sleep(5)
                if time.time() - self.last_activity > timeout:
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    async def _silence_watchdog(self):
        """Close the session on long silence; ASR stays up across short pauses."""
        try:
            while True:
                await asyncio.sleep(1.0)
                # 60 s without voice above the peak threshold ends the call.
                if time.time() - self.last_voice_activity > 60:
                    logger.info("Long silence, closing session")
                    await self.send_signal({"type": "error", "msg": "idle timeout"})
                    await self.close()
                    break
        except asyncio.CancelledError:
            return

    @staticmethod
    def _peak_pcm16(data: bytes) -> int:
        """Quickly estimate the peak amplitude of little-endian 16-bit PCM."""
        if not data:
            return 0
        view = memoryview(data)
        # One sample per 2 bytes; take the maximum absolute value.
        max_val = 0
        for i in range(0, len(view) - 1, 2):
            sample = int.from_bytes(view[i : i + 2], "little", signed=True)
            if sample < 0:
                sample = -sample
            if sample > max_val:
                max_val = sample
        return max_val
|
||
|
||
|
||
@router.websocket("/call")
async def voice_call(websocket: WebSocket):
    """WebSocket entry point for a realtime voice-call session."""
    try:
        user = await authenticate_websocket(websocket)
    except HTTPException:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # PTT can be forced globally or requested per-connection via ?ptt=1.
    ptt_flag = (websocket.query_params.get("ptt") or "").strip().lower()
    require_ptt = settings.VOICE_CALL_REQUIRE_PTT or ptt_flag in ("1", "true", "yes", "on")
    session = VoiceCallSession(websocket, user, require_ptt=require_ptt)
    try:
        await session.start()
    except HTTPException as exc:
        # Session startup failed (e.g. ASR misconfigured): report and close.
        try:
            await websocket.accept()
            await websocket.send_text(json.dumps({"type": "error", "msg": exc.detail}))
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
        except Exception:
            pass
        return

    try:
        while True:
            msg = await websocket.receive()
            audio = msg.get("bytes")
            if audio is not None:
                await session.feed_audio(audio)
            elif msg.get("text"):
                # Lightweight heartbeat / control messages.
                raw = msg["text"].strip()
                command = raw.lower()
                if command in ("mic_on", "ptt_on"):
                    await session.set_mic_enabled(True)
                elif command in ("mic_off", "ptt_off"):
                    await session.set_mic_enabled(False, flush=True)
                elif raw == "ping":
                    await websocket.send_text("pong")
                elif raw in ("end", "stop", "flush"):
                    session.finalize_asr()
                    await session.send_signal({"type": "info", "msg": "ASR stopped manually"})
                else:
                    await session.send_signal({"type": "info", "msg": "文本消息已忽略"})
            if msg.get("type") == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        pass
    finally:
        await session.close()
|