Ai_GirlFriend/lover/routers/chat.py
2026-01-31 19:15:41 +08:00

1165 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime, date
from typing import List, Optional
import re
import random
import oss2
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import desc
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from ..config import settings
from ..db import get_db
from ..deps import AuthedUser, get_current_user
from ..llm import chat_completion
from ..models import ChatMessage, ChatSession, Lover, User, ChatFact, ChatSummary, VoiceLibrary # type: ignore
from ..response import ApiResponse, success_response
from ..tts import synthesize
from ..vision import describe_image
from .config import _parse_owned_voices
# Router that collects every endpoint defined in this module under the /chat prefix.
router = APIRouter(prefix="/chat", tags=["chat"])
class ChatSendIn(BaseModel):
    """Request body of POST /chat/send."""

    # Target session id (required).
    session_id: int = Field(..., description="会话ID必填")
    # User text; validated to 1..4000 characters.
    message: str = Field(..., min_length=1, max_length=4000)
class ChatSendOut(BaseModel):
    """Response payload of POST /chat/send."""

    session_id: int
    # Ids of the stored user message and of the stored lover reply.
    user_message_id: int
    lover_message_id: int
    # Final reply text after inner-voice stripping and length capping.
    reply: str
    inner_voice_enabled: bool
    # Usage dict taken from the LLM result, when available.
    usage: Optional[dict] = None
    # Daily quota and the number of chats consumed today (after this call).
    chat_limit_daily: int
    chat_used_today: int
class ChatSendImageOut(BaseModel):
    """Response payload of POST /chat/send-image (same fields as ChatSendOut)."""

    session_id: int
    # Ids of the stored image message and of the stored lover reply.
    user_message_id: int
    lover_message_id: int
    # Final reply text after inner-voice stripping and length capping.
    reply: str
    inner_voice_enabled: bool
    # Usage dict taken from the LLM result, when available.
    usage: Optional[dict] = None
    # Daily quota and the number of chats consumed today (after this call).
    chat_limit_daily: int
    chat_used_today: int
class ChatConfigOut(BaseModel):
    """Chat configuration returned by GET /chat/config and PUT /chat/inner-voice."""

    # None when the config was requested without a session id.
    session_id: Optional[int] = None
    chat_limit_daily: int
    chat_used_today: int
    inner_voice_enabled: bool
    is_vip: bool
class MessageOut(BaseModel):
    """Serialized chat message as returned by the listing endpoints."""

    id: int
    role: str
    content: str
    # Per-session sequence number; the endpoints substitute 0 when it is missing.
    seq: int
    created_at: datetime
    content_type: Optional[str] = None
    extra: Optional[dict] = None
    # Reserved so the frontend can tell whether TTS already exists and avoid duplicate requests.
    tts_url: Optional[str] = None
    tts_status: Optional[str] = None
    tts_voice_id: Optional[int] = None
    tts_model_id: Optional[str] = None
class TTSOut(BaseModel):
    """Response payload of POST /chat/messages/tts/{message_id}."""

    message_id: int
    # URL of the synthesized audio, when one exists.
    tts_url: Optional[str] = None
    tts_status: str
    tts_voice_id: Optional[int] = None
    tts_model_id: Optional[str] = None
class SessionInitOut(BaseModel):
    """Response payload of GET /chat/session/init."""

    session_id: int
    # One page of messages, ordered by ascending seq.
    messages: List[MessageOut]
    chat_limit_daily: int
    chat_used_today: int
    inner_voice_enabled: bool
    is_vip: bool
class MessagesPageOut(BaseModel):
    """Response payload of GET /chat/messages."""

    session_id: int
    # Echo of the requested page number and page size.
    page: int
    size: int
    # Messages of this page, ordered by ascending seq.
    messages: List[MessageOut]
    # True when page * size is still smaller than the total message count.
    has_more: bool
def _reset_chat_quota(user_row: User):
    """Start a fresh daily counter when the stored reset date is not today.

    Mutates ``user_row`` in place (``chat_reset_date`` and ``chat_used_today``);
    nothing is added to a session or flushed here.
    """
    current_day = date.today()
    if user_row.chat_reset_date == current_day:
        return
    user_row.chat_reset_date = current_day
    user_row.chat_used_today = 0
def _ensure_quota(user_row: User):
    """Raise when the user has no chat quota left for today.

    The daily counter is rolled over first via ``_reset_chat_quota``. The limit
    is the user's own ``chat_limit_daily`` or, when that is falsy,
    ``settings.CHAT_LIMIT_DAILY``.

    Raises:
        HTTPException: 400 when today's usage has reached the limit.
    """
    _reset_chat_quota(user_row)
    limit = user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY
    # A missing (None) counter is treated as 0, like every other read of
    # chat_used_today in this module, instead of raising TypeError on ">=".
    used_today = user_row.chat_used_today or 0
    if used_today >= limit:
        raise HTTPException(status_code=400, detail="今日聊天次数已用完")
def _pick_llm_seed() -> int:
return random.randint(0, 2**31 - 1)
def _normalize_text(text: Optional[str]) -> str:
if not text:
return ""
cleaned = re.sub(r"\s+", " ", text)
return cleaned.strip()
_CONTEXT_MSG_OVERHEAD_TOKENS = 4
def _estimate_tokens_for_context(text: Optional[str]) -> int:
if not text:
return 0
ascii_chars = 0
for ch in text:
if ord(ch) < 128:
ascii_chars += 1
non_ascii_chars = len(text) - ascii_chars
tokens = non_ascii_chars + (ascii_chars + 3) // 4
return max(1, tokens)
def _estimate_message_tokens(msg: dict) -> int:
content = msg.get("content") if isinstance(msg, dict) else ""
return _estimate_tokens_for_context(content or "") + _CONTEXT_MSG_OVERHEAD_TOKENS
def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
if not text or max_tokens <= 0:
return ""
if _estimate_tokens_for_context(text) <= max_tokens:
return text
approx_tokens = _estimate_tokens_for_context(text)
ratio = max_tokens / max(1, approx_tokens)
new_len = max(1, int(len(text) * ratio))
trimmed = text[:new_len]
for _ in range(6):
if _estimate_tokens_for_context(trimmed) <= max_tokens:
break
new_len = max(1, int(new_len * 0.85))
trimmed = text[:new_len]
return trimmed
def _truncate_message_to_budget(msg: dict, max_tokens: int) -> dict:
content = msg.get("content") if isinstance(msg, dict) else ""
content_budget = max_tokens - _CONTEXT_MSG_OVERHEAD_TOKENS
trimmed_content = _truncate_text_to_tokens(content or "", max(0, content_budget))
new_msg = dict(msg)
new_msg["content"] = trimmed_content
return new_msg
def _trim_context_messages(messages: List[dict]) -> List[dict]:
    """Shrink ``messages`` until its estimated size fits the context budget.

    The budget is ``settings.CHAT_CONTEXT_MAX_TOKENS``; a falsy or non-positive
    value disables trimming and the input list is returned unchanged, as it is
    when the messages already fit. Otherwise the reduction is applied in this
    order, stopping as soon as the estimate fits:

    1. drop trailing system messages of the leading system block, keeping the first;
    2. drop the oldest dialogue messages, keeping at least the newest;
    3. truncate the first remaining dialogue message;
    4. truncate the first system message.
    """
    budget = settings.CHAT_CONTEXT_MAX_TOKENS or 0
    if budget <= 0:
        return messages

    # The leading run of system messages is handled separately from the dialogue.
    split_at = len(messages)
    for idx, item in enumerate(messages):
        if item.get("role") != "system":
            split_at = idx
            break
    system_msgs = list(messages[:split_at])
    dialogue = list(messages[split_at:])

    def cost(items) -> int:
        return sum(_estimate_message_tokens(entry) for entry in items)

    used = cost(system_msgs) + cost(dialogue)
    if used <= budget:
        return messages

    # Step 1: optional system messages go first, from the end.
    while used > budget and len(system_msgs) > 1:
        used -= _estimate_message_tokens(system_msgs.pop())

    # Step 2: oldest dialogue messages, never the last one.
    if dialogue:
        dropped = 0
        last_index = len(dialogue) - 1
        while dropped < last_index and used > budget:
            used -= _estimate_message_tokens(dialogue[dropped])
            dropped += 1
        if dropped:
            dialogue = dialogue[dropped:]

    # Step 3: squeeze the earliest surviving dialogue message into what is left.
    if dialogue and used > budget:
        room = budget - cost(system_msgs) - cost(dialogue[1:])
        dialogue[0] = _truncate_message_to_budget(dialogue[0], room)
        used = cost(system_msgs) + cost(dialogue)

    # Step 4: last resort, cut the base system prompt.
    if system_msgs and used > budget:
        system_msgs[0] = _truncate_message_to_budget(system_msgs[0], budget - cost(dialogue))

    return system_msgs + dialogue
def _is_exact_repeat(last_msg: Optional[ChatMessage], reply_text: str) -> bool:
    """Tell whether ``reply_text`` equals the previous lover message after whitespace normalization.

    Always False when there is no previous message or it was not sent by the lover.
    """
    if last_msg and last_msg.role == "lover":
        previous_text = _normalize_text(last_msg.content or "")
        return previous_text == _normalize_text(reply_text)
    return False
def _fetch_session(db: Session, user_id: int, session_id: Optional[int]) -> Optional[ChatSession]:
    """Load a chat session by id, restricted to sessions owned by ``user_id``.

    Returns:
        ``None`` when ``session_id`` is falsy, otherwise the matching session.

    Raises:
        HTTPException: 404 when no session with that id belongs to the user.
    """
    if not session_id:
        return None
    owned_sessions = db.query(ChatSession).filter(
        ChatSession.id == session_id,
        ChatSession.user_id == user_id,
    )
    found = owned_sessions.first()
    if not found:
        raise HTTPException(status_code=404, detail="会话不存在")
    return found
def _create_session(db: Session, user: AuthedUser, lover: Lover, inner_voice_enabled: bool) -> ChatSession:
    """Create an active chat session for ``user`` and ``lover`` and flush it.

    All three timestamps are set to the same naive UTC instant. The flush runs
    before returning, so the caller can read ``session.id`` from the result.
    """
    timestamp = datetime.utcnow()
    fields = {
        "user_id": user.id,
        "lover_id": lover.id,
        "model": settings.LLM_MODEL or "qwen-flash",
        "status": "active",
        "last_message_at": timestamp,
        "created_at": timestamp,
        "updated_at": timestamp,
        "inner_voice_enabled": inner_voice_enabled,
    }
    new_session = ChatSession(**fields)
    db.add(new_session)
    db.flush()
    return new_session
def _build_system_prompt(lover: Lover, user: AuthedUser, inner_voice_enabled: bool) -> str:
    """Assemble the base system prompt as newline-joined instructions.

    The prompt contains, in order: the role statement addressed to the user's
    nickname (or a generic fallback), the content restrictions, the optional
    personality setting of the lover, the no-verbatim-repeat rule, and one
    instruction that either enables or forbids bracketed inner-voice text.
    """
    display_name = user.nickname or '用户'
    lines = [
        f"你是用户 {display_name} 的虚拟恋人,请用亲密、温暖的语气聊天。",
        "禁止涉政、违法、暴力、未成年相关内容。",
    ]
    persona = lover.personality_prompt
    if persona:
        lines.append(f"人格设定:{persona}")
    lines.append("避免与上一条回复完全一致;若用户要求复述,请换一种说法简要转述。")
    inner_voice_rule = (
        "在适当时机用全角括号(…)加入心理活动,正文与心声并存。"
        if inner_voice_enabled
        else "不要输出括号(…)内的心理活动。"
    )
    lines.append(inner_voice_rule)
    return "\n".join(lines)
def _build_context_messages(
    db: Session,
    session: ChatSession,
    lover: Lover,
    user: AuthedUser,
    inner_voice_enabled: bool,
) -> List[dict]:
    """Build the message list sent to the LLM for the next reply.

    The list is assembled in this order and then passed through
    ``_trim_context_messages``:

    1. the base system prompt;
    2. the latest stored summary of the session (highest ``upto_seq``), if any;
    3. up to ``settings.CHAT_FACT_MAX_ROWS`` facts for this user/lover pair,
       ordered by weight and then by creation time, both descending;
    4. the most recent ``settings.CHAT_RECENT_WINDOW`` messages of the session
       in chronological order, with image messages replaced by a text caption.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    messages: List[dict] = [
        {"role": "system", "content": _build_system_prompt(lover, user, inner_voice_enabled)}
    ]

    # Latest summary.
    summary = (
        db.query(ChatSummary)
        .filter(ChatSummary.session_id == session.id)
        .order_by(ChatSummary.upto_seq.desc())
        .first()
    )
    if summary:
        # The bracket is closed and a colon separates the seq number from the
        # summary body; previously the two ran together (e.g. "seq 12<text>").
        messages.append(
            {
                "role": "system",
                "content": f"对话摘要(累积至 seq {summary.upto_seq}):{summary.summary_text}",
            }
        )

    # Profile facts, ordered by weight and then recency.
    facts = (
        db.query(ChatFact)
        .filter(ChatFact.user_id == user.id, ChatFact.lover_id == lover.id)
        .order_by(desc(ChatFact.weight), ChatFact.created_at.desc())
        .limit(settings.CHAT_FACT_MAX_ROWS)
        .all()
    )
    if facts:
        fact_lines = [f"- {f.kind or 'fact'}: {f.content}" for f in facts]
        messages.append({"role": "system", "content": "画像与重要事实:\n" + "\n".join(fact_lines)})

    # Recent message window; fetched newest-first so the limit keeps the latest ones.
    history = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc())
        .limit(settings.CHAT_RECENT_WINDOW)
        .all()
    )
    role_map = {"lover": "assistant", "user": "user"}
    # Reverse back to chronological order for the LLM.
    for msg in reversed(history):
        role = role_map.get(msg.role, "system")
        if msg.content_type == "image":
            # Image messages enter the context as a textual description.
            try:
                caption = (msg.extra or {}).get("image_caption")
            except (AttributeError, TypeError):
                # extra is not dict-like; fall back to the generic wording below.
                caption = None
            if msg.role == "user":
                user_nickname = user.nickname or "用户"
                content = f"{user_nickname}给你发了一张图片,内容是:{caption or '图片已收到,但未能识别'}"
            else:
                content = f"恋人发送了一张图片,内容是:{caption or '图片'}"
        else:
            content = msg.content
        messages.append({"role": role, "content": content})

    return _trim_context_messages(messages)
def _is_vip(user_row: User) -> bool:
    """Tell whether the user's VIP period is still running.

    ``vip_endtime`` is compared, as an integer, with the current Unix time in
    seconds. A missing or zero end time, and any value that cannot be
    converted, yields False.
    """
    # Local import: `time` is not among this module's top-level imports.
    import time

    try:
        # time.time() is the real epoch time. The previous
        # datetime.utcnow().timestamp() interpreted a naive UTC value as local
        # time, which shifts "now" by the server's UTC offset on non-UTC hosts.
        now_ts = int(time.time())
        return bool(user_row.vip_endtime and int(user_row.vip_endtime) > now_ts)
    except Exception:
        # Deliberate best effort: any unreadable value means "not VIP".
        return False
def _get_or_create_primary_session(db: Session, user: AuthedUser, lover: Lover) -> ChatSession:
    """Return the single active session for this user/lover pair, creating it if needed.

    When several active sessions exist, the most recently created one is kept
    and the others are marked ``archived``. When none exists a new session is
    created, with the inner-voice switch taken from the user's default or, if
    that is unset, ``settings.INNER_VOICE_DEFAULT``, and the lover's opening
    line is stored as message ``seq=1``.

    If creation fails with ``IntegrityError``, the transaction is rolled back
    and the newest existing session of the pair (of any status) is
    re-activated and returned; the error is re-raised when there is none.
    """
    active_rows = (
        db.query(ChatSession)
        .filter(
            ChatSession.user_id == user.id,
            ChatSession.lover_id == lover.id,
            ChatSession.status == "active",
        )
        .with_for_update()
        .order_by(ChatSession.created_at.desc())
        .all()
    )
    if active_rows:
        newest, *stale = active_rows
        for leftover in stale:
            leftover.status = "archived"
            db.add(leftover)
        db.flush()
        return newest

    # No active session: decide the inner-voice default for the new one.
    user_row = db.query(User).filter(User.id == user.id).first()
    if user_row and user_row.inner_voice_enabled is not None:
        inner_voice = user_row.inner_voice_enabled
    else:
        inner_voice = settings.INNER_VOICE_DEFAULT

    try:
        session = _create_session(db, user, lover, inner_voice_enabled=inner_voice)
    except IntegrityError:
        # Concurrent creation, or an archived session hitting the unique key:
        # roll back and fall back to the session that already exists.
        db.rollback()
        fallback = (
            db.query(ChatSession)
            .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id)
            .with_for_update()
            .order_by(ChatSession.created_at.desc())
            .first()
        )
        if not fallback:
            raise
        if fallback.status != "active":
            fallback.status = "active"
            fallback.updated_at = datetime.utcnow()
            db.add(fallback)
            db.flush()
        return fallback

    # Store the opening line as the lover's first message of the new session.
    stamp = datetime.utcnow()
    greeting = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=lover.opening_line or "嗨,我是你的恋人,很高兴见到你~",
        seq=1,
        created_at=stamp,
        model=settings.LLM_MODEL or "qwen-flash",
    )
    db.add(greeting)
    session.last_message_at = stamp
    db.add(session)
    db.flush()
    return session
def _strip_inner_voice_text(raw: Optional[str]) -> str:
"""
去除括号中的内心/旁白内容,便于生成 TTS。
"""
if not raw:
return ""
# 去掉全角括号内的内容(常用于心声/旁白)
cleaned = re.sub(r".*?", "", raw)
cleaned = re.sub(r"\s+", " ", cleaned)
return cleaned.strip()
def _pick_available_voice(db: Session, lover: Lover, user_row: User) -> VoiceLibrary:
    """Choose the voice used to synthesize speech for ``lover``.

    Preference order, always restricted to the lover's gender:

    1. the lover's configured voice, unless it has a positive gold price and
       its id is not among the user's owned voices;
    2. the voice flagged as default;
    3. the voice with the lowest id.

    Raises:
        HTTPException: 404 when no voice is found, 500 when the chosen voice
            has no ``voice_code``.
    """
    owned_ids = _parse_owned_voices(user_row.owned_voice_ids)
    chosen: Optional[VoiceLibrary] = None

    if lover.voice_id:
        configured = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.id == lover.voice_id, VoiceLibrary.gender == lover.gender)
            .first()
        )
        if configured:
            # Paid voices must be owned; free ones can be used directly.
            is_paid = (configured.price_gold or 0) > 0
            if not is_paid or int(configured.id) in owned_ids:
                chosen = configured

    if not chosen:
        same_gender = db.query(VoiceLibrary).filter(
            VoiceLibrary.gender == lover.gender,
            VoiceLibrary.is_default.is_(True),
        )
        chosen = same_gender.first()

    if not chosen:
        any_voice = db.query(VoiceLibrary).filter(VoiceLibrary.gender == lover.gender)
        chosen = any_voice.order_by(VoiceLibrary.id.asc()).first()

    if not chosen:
        raise HTTPException(status_code=404, detail="未找到可用音色")
    if not chosen.voice_code:
        raise HTTPException(status_code=500, detail="音色未配置 voice_code")
    return chosen
def _upload_tts_to_oss(file_bytes: bytes, lover_id: int, message_id: int) -> str:
    """Upload synthesized audio to Aliyun OSS and return its URL.

    The object is stored as ``lover/<lover_id>/tts/<message_id>.mp3``. The
    returned URL is built on the configured CDN domain when there is one,
    otherwise on ``https://<bucket>.<endpoint host>``.

    Raises:
        HTTPException: 500 when the OSS credentials or the bucket/endpoint are
            not configured, 502 when the upload fails.
    """
    key_id = settings.ALIYUN_OSS_ACCESS_KEY_ID
    key_secret = settings.ALIYUN_OSS_ACCESS_KEY_SECRET
    if not (key_id and key_secret):
        raise HTTPException(status_code=500, detail="未配置 OSS Key")

    bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
    if not (bucket_name and settings.ALIYUN_OSS_ENDPOINT):
        raise HTTPException(status_code=500, detail="未配置 OSS Bucket/Endpoint")

    object_name = f"lover/{lover_id}/tts/{message_id}.mp3"
    endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/")
    try:
        credentials = oss2.Auth(key_id, key_secret)
        oss2.Bucket(credentials, endpoint, bucket_name).put_object(object_name, file_bytes)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"上传语音失败: {exc}") from exc

    cdn = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn:
        base_url = cdn.rstrip("/")
        return f"{base_url}/{object_name}"
    # Without a CDN, address the bucket directly on the endpoint host.
    host = endpoint.replace("https://", "").replace("http://", "")
    return f"https://{bucket_name}.{host}/{object_name}"
@router.get("/session/init", response_model=ApiResponse[SessionInitOut])
def init_session(
    page: int = Query(1, ge=1),
    size: int = Query(15, ge=1, le=100),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the user's primary chat session together with one page of messages.

    The session is created on first use. Messages are paged newest-first by
    ``seq`` and each page is returned in ascending ``seq`` order. The response
    also carries today's quota figures, the effective inner-voice switch and
    the VIP flag. If inner voice is on but the user is not VIP, the switch is
    turned off on the session.

    Raises:
        HTTPException: 404 when the user has no lover yet or the user row is missing.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    session = _get_or_create_primary_session(db, user, lover)
    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    _reset_chat_quota(user_row)

    # Only the requested page is loaded. The earlier COUNT query was dropped
    # because its result was never used by this endpoint.
    records = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc())
        .offset((page - 1) * size)
        .limit(size)
        .all()
    )
    # Fetched newest-first; hand them out oldest-first.
    records = sorted(records, key=lambda m: m.seq or 0)

    def to_message(msg: ChatMessage) -> MessageOut:
        return MessageOut(
            id=msg.id,
            role=msg.role,
            content=msg.content,
            seq=msg.seq or 0,
            created_at=msg.created_at,
            content_type=getattr(msg, "content_type", None),
            extra=getattr(msg, "extra", None),
            tts_url=getattr(msg, "tts_url", None),
            tts_status=getattr(msg, "tts_status", None),
            tts_voice_id=getattr(msg, "tts_voice_id", None),
            tts_model_id=getattr(msg, "tts_model_id", None),
        )

    # Session setting wins, then the user's default, then the global default.
    if session.inner_voice_enabled is not None:
        inner_voice = session.inner_voice_enabled
    elif user_row.inner_voice_enabled is not None:
        inner_voice = user_row.inner_voice_enabled
    else:
        inner_voice = settings.INNER_VOICE_DEFAULT

    is_vip = _is_vip(user_row)
    # A user who is no longer VIP must not keep inner voice switched on.
    if inner_voice and not is_vip:
        inner_voice = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()

    return success_response(
        SessionInitOut(
            session_id=session.id,
            messages=[to_message(m) for m in records],
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
            inner_voice_enabled=bool(inner_voice),
            is_vip=is_vip,
        )
    )
@router.get("/messages", response_model=ApiResponse[MessagesPageOut])
def list_messages(
    session_id: int,
    page: int = Query(1, ge=1),
    size: int = Query(15, ge=1, le=100),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return one page of messages of a session owned by the current user.

    Pages are cut newest-first by message id; inside a page the messages are
    returned in ascending ``seq`` order. ``has_more`` is True while
    ``page * size`` is below the total message count.

    Raises:
        HTTPException: 404 when the session does not exist for this user.
    """
    session = _fetch_session(db, user.id, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    newest_first = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.id.desc())
    )
    total = newest_first.count()
    skip = (page - 1) * size
    page_rows = newest_first.offset(skip).limit(size).all()
    # Within the page, present messages in ascending seq order.
    page_rows = sorted(page_rows, key=lambda row: row.seq or 0)

    optional_attrs = (
        "content_type",
        "extra",
        "tts_url",
        "tts_status",
        "tts_voice_id",
        "tts_model_id",
    )

    def serialize(row: ChatMessage) -> MessageOut:
        # Optional columns are read defensively and default to None.
        optional = {attr: getattr(row, attr, None) for attr in optional_attrs}
        return MessageOut(
            id=row.id,
            role=row.role,
            content=row.content,
            seq=row.seq or 0,
            created_at=row.created_at,
            **optional,
        )

    page_out = MessagesPageOut(
        session_id=session.id,
        page=page,
        size=size,
        messages=[serialize(row) for row in page_rows],
        has_more=page * size < total,
    )
    return success_response(page_out)
@router.post("/messages/tts/{message_id}", response_model=ApiResponse[TTSOut])
def generate_message_tts(
    message_id: int,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Synthesize speech for one lover text message and store the audio URL.

    The message row is read with a locking query. Inner-voice text is stripped
    before synthesis and the remaining text must be 1..80 characters. A
    message that already has a successful TTS result is returned as is. When
    synthesis fails with a "model not found" style error, one retry is made
    with the configured fallback voice/model.

    Raises:
        HTTPException: 404 for a missing message, lover, user or voice; 400 for
            an unsupported message, empty text or text over 80 characters; 502
            (or the original error) when synthesis or upload fails.
    """
    msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.id == message_id, ChatMessage.user_id == user.id)
        .with_for_update()
        .first()
    )
    if not msg:
        raise HTTPException(status_code=404, detail="消息不存在")
    if msg.role != "lover" or msg.content_type != "text":
        raise HTTPException(status_code=400, detail="仅支持恋人文本消息生成语音")
    lover = db.query(Lover).filter(Lover.id == msg.lover_id, Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人不存在")
    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    clean_text = _strip_inner_voice_text(msg.content or "")
    if not clean_text:
        raise HTTPException(status_code=400, detail="消息内容为空,无法生成语音")
    if len(clean_text) > 80:
        raise HTTPException(status_code=400, detail="文本长度超过80字符不支持生成语音")

    # Already synthesized: return the stored result instead of paying again.
    if msg.tts_url and msg.tts_status == "succeeded":
        return success_response(
            TTSOut(
                message_id=msg.id,
                tts_url=msg.tts_url,
                tts_status=msg.tts_status or "succeeded",
                tts_voice_id=getattr(msg, "tts_voice_id", None),
                tts_model_id=getattr(msg, "tts_model_id", None),
            )
        )

    voice = _pick_available_voice(db, lover, user_row)
    model = voice.tts_model_id or "cosyvoice-v2"

    # Mark as pending first so concurrent requests do not generate twice.
    msg.tts_status = "pending"
    msg.tts_error = None
    db.add(msg)
    db.flush()

    # Voice/model actually used; they only differ from the picked voice when
    # the fallback path below succeeds.
    used_fallback = False
    used_model = model
    try:
        audio_bytes, fmt_name = synthesize(
            clean_text,
            model=model,
            voice=voice.voice_code,
        )
        url = _upload_tts_to_oss(audio_bytes, lover.id, msg.id)
    except HTTPException as exc:
        detail_text = str(exc.detail) if hasattr(exc, "detail") else str(exc)
        url = ""
        if "ModelNotFound" in detail_text or "not found" in detail_text:
            fallback_voice = settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
            fallback_model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
            # Retrying with the very same voice and model would be pointless.
            if fallback_voice != voice.voice_code or fallback_model != model:
                try:
                    audio_bytes, fmt_name = synthesize(
                        clean_text,
                        model=fallback_model,
                        voice=fallback_voice,
                    )
                    url = _upload_tts_to_oss(audio_bytes, lover.id, msg.id)
                    used_fallback = True
                    used_model = fallback_model
                except Exception:
                    # Best effort: a failed fallback is reported as the original error.
                    url = ""
        if not url:
            msg.tts_status = "failed"
            msg.tts_error = detail_text[:255]
            db.add(msg)
            db.flush()
            raise
    except Exception as exc:
        msg.tts_status = "failed"
        msg.tts_error = str(exc)[:255]
        db.add(msg)
        db.flush()
        raise HTTPException(status_code=502, detail="语音生成失败,请稍后重试") from exc

    msg.tts_url = url
    msg.tts_status = "succeeded"
    # The fallback voice is not bound to a voice-library id. Previously the
    # None set on the fallback path was overwritten by "or voice.id", which
    # recorded a voice that had not produced the audio.
    msg.tts_voice_id = None if used_fallback else voice.id
    msg.tts_model_id = used_model
    msg.tts_format = fmt_name
    msg.tts_error = None
    db.add(msg)
    db.flush()
    return success_response(
        TTSOut(
            message_id=msg.id,
            tts_url=msg.tts_url,
            tts_status=msg.tts_status or "succeeded",
            tts_voice_id=getattr(msg, "tts_voice_id", None),
            tts_model_id=getattr(msg, "tts_model_id", None),
        )
    )
class ChatSendImageIn(BaseModel):
    """Request body of POST /chat/send-image."""

    # Existing session to post into.
    session_id: int
    # URL of the image; the endpoint stores it as the message content.
    image_url: str
@router.post("/send-image", response_model=ApiResponse[ChatSendImageOut])
def send_image_message(
    payload: ChatSendImageIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Post an image message (VIP only) and return the lover's reply.

    Steps: validate the URL, check VIP status and daily quota, store the user
    image message, caption the image, build the LLM context, generate a reply
    (retrying once when it repeats the previous lover message verbatim), store
    the reply, bump the quota counter and possibly refresh the summary.

    Raises:
        HTTPException: 400 for a non-http(s) URL, an exhausted quota, or when
            the last message of the session is still an unanswered user
            message; 403 when the user is not VIP; 404 when the lover, user or
            session is missing.
    """
    img_url = payload.image_url or ""
    # Require a real http(s) URL; a bare "http" prefix check also accepted
    # values such as "httpfoo" or other schemes starting with "http".
    if not img_url.startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="图片URL格式不正确")

    # The lover must exist.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")

    # Locking read of the user row, which is updated (quota counter) below.
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    # VIP check.
    if not _is_vip(user_row):
        raise HTTPException(status_code=403, detail="仅 VIP 用户可发送图片")
    _ensure_quota(user_row)  # images count against the daily quota too

    # An existing session is required.
    session = _fetch_session(db, user.id, payload.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    # No consecutive sends: if the latest message is the user's, wait for the reply.
    last_msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .with_for_update()
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    if last_msg and last_msg.role == "user":
        raise HTTPException(status_code=400, detail="请等待恋人回复后再发送")

    # Next sequence number.
    next_seq = (last_msg.seq if last_msg and last_msg.seq else 0) + 1
    now = datetime.utcnow()
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="image",
        content=img_url,
        seq=next_seq,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
        extra={},
    )
    db.add(user_msg)
    db.flush()

    # Image understanding; the default caption is kept when it fails.
    caption = "图片已收到"
    try:
        caption = describe_image(img_url)
    except HTTPException as exc:
        # Keep the failure details but do not block the conversation.
        user_msg.extra = {"image_caption": caption, "image_error": str(exc.detail)}
    else:
        user_msg.extra = {"image_caption": caption}
    db.add(user_msg)
    db.flush()

    inner_voice_enabled = (
        session.inner_voice_enabled
        if session.inner_voice_enabled is not None
        else settings.INNER_VOICE_DEFAULT
    )
    if inner_voice_enabled and not _is_vip(user_row):
        inner_voice_enabled = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()

    context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)
    llm_result = chat_completion(
        context_messages,
        temperature=settings.LLM_TEMPERATURE,
        max_tokens=settings.LLM_MAX_TOKENS,
        seed=_pick_llm_seed(),
    )
    reply_text = llm_result.content or ""
    # With inner voice off, strip bracketed asides from the reply.
    if not inner_voice_enabled:
        reply_text = _strip_inner_voice_text(reply_text)
    # Reply length cap; 0 disables truncation.
    max_chars = settings.CHAT_REPLY_MAX_CHARS or 0
    if max_chars > 0 and len(reply_text) > max_chars:
        reply_text = reply_text[:max_chars]

    if _is_exact_repeat(last_msg, reply_text):
        # Same text as the previous lover message: ask once more for a rephrasing.
        retry_messages = list(context_messages)
        retry_messages.append(
            {
                "role": "system",
                "content": "请避免与上一条恋人回复完全相同,换一种说法简要复述,保持亲密语气。",
            }
        )
        retry_result = chat_completion(
            retry_messages,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
            seed=_pick_llm_seed(),
        )
        retry_text = retry_result.content or ""
        if not inner_voice_enabled:
            retry_text = _strip_inner_voice_text(retry_text)
        if max_chars > 0 and len(retry_text) > max_chars:
            retry_text = retry_text[:max_chars]
        if retry_text and not _is_exact_repeat(last_msg, retry_text):
            llm_result = retry_result
            reply_text = retry_text
        else:
            # Still a repeat (or empty): fall back to explicitly quoting the previous message.
            base_text = last_msg.content or reply_text
            reply_text = f"我刚才说的是:{base_text}"
            if max_chars > 0 and len(reply_text) > max_chars:
                reply_text = reply_text[:max_chars]

    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=reply_text,
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=(llm_result.usage or {}).get("input_tokens"),
        token_output=(llm_result.usage or {}).get("output_tokens"),
    )
    db.add(lover_msg)

    # Update the session and the quota counter.
    session.last_message_at = datetime.utcnow()
    user_row.chat_used_today = (user_row.chat_used_today or 0) + 1
    db.add(session)
    db.add(user_row)
    db.flush()

    _maybe_generate_summary(db, session, upto_seq=lover_msg.seq)
    return success_response(
        ChatSendImageOut(
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            reply=reply_text,
            inner_voice_enabled=bool(inner_voice_enabled),
            usage=llm_result.usage,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
        )
    )
def _maybe_generate_summary(db: Session, session: ChatSession, upto_seq: int):
    """Write a new rolling summary once enough unsummarized messages have accumulated.

    Two triggers are read from settings: ``CHAT_SUMMARY_TRIGGER_MSGS`` (message
    count) and ``CHAT_SUMMARY_TRIGGER_TOKENS`` (rough token count, two
    characters per token). A trigger that is unset or non-positive is
    disabled. Nothing happens when both are disabled or when no enabled
    trigger has been reached by the messages after the last summary, up to and
    including ``upto_seq``.

    The previous summary text, when there is one, is passed to the LLM along
    with the new messages, so the stored summary stays cumulative; only the
    latest summary is fed back into the chat context. An empty LLM result is
    not stored, so the same window is retried on the next call.
    """
    threshold = settings.CHAT_SUMMARY_TRIGGER_MSGS or 0
    token_threshold = settings.CHAT_SUMMARY_TRIGGER_TOKENS or 0
    if threshold <= 0 and token_threshold <= 0:
        return

    last_summary = (
        db.query(ChatSummary)
        .filter(ChatSummary.session_id == session.id)
        .order_by(ChatSummary.upto_seq.desc())
        .first()
    )
    last_upto = last_summary.upto_seq if last_summary else 0
    new_msgs = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id, ChatMessage.seq > last_upto, ChatMessage.seq <= upto_seq)
        .order_by(ChatMessage.seq.asc())
        .all()
    )

    def estimate_tokens(text: str) -> int:
        # Coarse estimate for mixed Chinese/English text: ~2 characters per token.
        if not text:
            return 0
        return max(1, len(text) // 2)

    total_tokens = sum(estimate_tokens(m.content or "") for m in new_msgs)

    # Proceed as soon as any enabled trigger is reached.
    msgs_reached = threshold > 0 and len(new_msgs) >= threshold
    tokens_reached = token_threshold > 0 and total_tokens >= token_threshold
    if not (msgs_reached or tokens_reached):
        return

    # Assemble the summarizer input.
    speaker_names = {"user": "用户", "lover": "恋人"}
    lines = [f"{speaker_names.get(m.role, '系统')}: {m.content}" for m in new_msgs]
    transcript = "\n".join(lines)
    prior_text = (last_summary.summary_text or "").strip() if last_summary else ""
    if prior_text:
        # Carry the earlier summary forward so older context is not lost.
        transcript = f"已有摘要:{prior_text}\n\n新增对话:\n{transcript}"

    prompt = [
        {
            "role": "system",
            "content": "你是对话摘要助手请用简洁中文总结关键事件、情绪、约定与待办50~120字。",
        },
        {"role": "user", "content": transcript},
    ]
    result = chat_completion(prompt, temperature=0.3, max_tokens=300)
    summary_text = result.content or ""
    if not summary_text.strip():
        # Storing an empty summary would advance upto_seq and wipe the
        # accumulated context; skip and let a later call retry.
        return

    summary = ChatSummary(
        session_id=session.id,
        upto_seq=upto_seq,
        summary_text=summary_text,
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=(result.usage or {}).get("input_tokens"),
        token_output=(result.usage or {}).get("output_tokens"),
        created_at=datetime.utcnow(),
    )
    db.add(summary)
    db.flush()
@router.post("/send", response_model=ApiResponse[ChatSendOut])
def send_message(
    payload: ChatSendIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Post a text message and return the lover's reply.

    Steps: check the daily quota, store the user message, build the LLM
    context, generate a reply (retrying once when it repeats the previous
    lover message verbatim), store the reply, bump the quota counter and
    possibly refresh the rolling summary.

    Raises:
        HTTPException: 404 when the lover, user or session is missing; 400 when
            the quota is exhausted or the last message of the session is still
            an unanswered user message.
    """
    # The lover must exist.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    # Quota check; the user row is read with a locking query because the
    # counter is incremented further down.
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    _ensure_quota(user_row)
    # An existing session is required.
    session = _fetch_session(db, user.id, payload.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")
    # No consecutive sends: if the latest message is the user's, wait for the reply.
    last_msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .with_for_update()
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    if last_msg and last_msg.role == "user":
        raise HTTPException(status_code=400, detail="请等待恋人回复后再发送")
    # Next sequence number.
    next_seq = (last_msg.seq if last_msg and last_msg.seq else 0) + 1
    now = datetime.utcnow()
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="text",
        content=payload.message,
        seq=next_seq,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
    )
    db.add(user_msg)
    # Flush so the new user message is visible to the history query in
    # _build_context_messages and user_msg.id is populated.
    db.flush()
    inner_voice_enabled = session.inner_voice_enabled if session.inner_voice_enabled is not None else settings.INNER_VOICE_DEFAULT
    # Non-VIP users get inner voice forced off, so a stale switch cannot linger.
    if inner_voice_enabled and not _is_vip(user_row):
        inner_voice_enabled = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()
    context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)
    llm_result = chat_completion(
        context_messages,
        temperature=settings.LLM_TEMPERATURE,
        max_tokens=settings.LLM_MAX_TOKENS,
        seed=_pick_llm_seed(),
    )
    reply_text = llm_result.content or ""
    if not inner_voice_enabled:
        reply_text = _strip_inner_voice_text(reply_text)
    # Reply length cap (0 means no truncation).
    max_chars = settings.CHAT_REPLY_MAX_CHARS or 0
    if max_chars > 0 and len(reply_text) > max_chars:
        reply_text = reply_text[:max_chars]
    # The reply equals the previous lover message: ask once more for a rephrasing.
    if _is_exact_repeat(last_msg, reply_text):
        retry_messages = list(context_messages)
        retry_messages.append(
            {
                "role": "system",
                "content": "请避免与上一条恋人回复完全相同,换一种说法简要复述,保持亲密语气。",
            }
        )
        retry_result = chat_completion(
            retry_messages,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
            seed=_pick_llm_seed(),
        )
        retry_text = retry_result.content or ""
        if not inner_voice_enabled:
            retry_text = _strip_inner_voice_text(retry_text)
        if max_chars > 0 and len(retry_text) > max_chars:
            retry_text = retry_text[:max_chars]
        if retry_text and not _is_exact_repeat(last_msg, retry_text):
            llm_result = retry_result
            reply_text = retry_text
        else:
            # Still a repeat (or empty): fall back to explicitly quoting the previous message.
            base_text = last_msg.content or reply_text
            reply_text = f"我刚才说的是:{base_text}"
            if max_chars > 0 and len(reply_text) > max_chars:
                reply_text = reply_text[:max_chars]
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=reply_text,
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=(llm_result.usage or {}).get("input_tokens"),
        token_output=(llm_result.usage or {}).get("output_tokens"),
    )
    db.add(lover_msg)
    # Update the session and the quota counter.
    session.last_message_at = datetime.utcnow()
    user_row.chat_used_today = (user_row.chat_used_today or 0) + 1
    db.add(session)
    db.add(user_row)
    db.flush()
    # Trigger the rolling summary if a threshold has been reached.
    _maybe_generate_summary(db, session, upto_seq=lover_msg.seq)
    return success_response(
        ChatSendOut(
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            reply=reply_text,
            inner_voice_enabled=bool(inner_voice_enabled),
            usage=llm_result.usage,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
        )
    )
class InnerVoiceIn(BaseModel):
    """Request body of PUT /chat/inner-voice."""

    # Session whose inner-voice switch is changed.
    session_id: int
    # Desired state of the switch.
    enabled: bool
@router.put("/inner-voice", response_model=ApiResponse[ChatConfigOut])
def update_inner_voice(
    payload: InnerVoiceIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Turn the inner-voice switch of a session on or off.

    Only VIP users may turn it on; anyone may turn it off. The session must
    already exist; it is not created here.

    Raises:
        HTTPException: 404 when the session, lover or user is missing; 403 when
            a non-VIP user tries to enable inner voice.
    """
    session = _fetch_session(db, user.id, payload.session_id)
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    user_row = db.query(User).filter(User.id == user.id).first()
    # Same 404 as the other endpoints. Without this check a missing row
    # reached _reset_chat_quota(None) and failed with AttributeError.
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    # Only VIP users may enable inner voice.
    if payload.enabled and not _is_vip(user_row):
        raise HTTPException(status_code=403, detail="仅 VIP 用户可开启心声")

    session.inner_voice_enabled = payload.enabled
    db.add(session)
    db.flush()

    # Roll the daily counter over if needed so the reported usage is current.
    _reset_chat_quota(user_row)
    return success_response(
        ChatConfigOut(
            session_id=session.id,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
            inner_voice_enabled=payload.enabled,
            is_vip=_is_vip(user_row),
        )
    )
@router.get("/config", response_model=ApiResponse[ChatConfigOut])
def get_chat_config(
    session_id: Optional[int] = Query(default=None),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the quota figures, inner-voice state and VIP flag of the current user.

    ``session_id`` is optional. When given, the session's own inner-voice
    switch takes precedence over the user's default and the global default.
    If inner voice resolves to on but the user is not VIP, it is reported as
    off and, when a session was given, switched off on that session.

    Raises:
        HTTPException: 404 when the session or the user row is missing.
    """
    session = None
    if session_id:
        session = _fetch_session(db, user.id, session_id)

    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    _reset_chat_quota(user_row)

    # Precedence: session switch, then user default, then global default.
    if session and session.inner_voice_enabled is not None:
        inner_voice = session.inner_voice_enabled
    elif user_row.inner_voice_enabled is not None:
        inner_voice = user_row.inner_voice_enabled
    else:
        inner_voice = settings.INNER_VOICE_DEFAULT

    if inner_voice and not _is_vip(user_row):
        inner_voice = False
        if session:
            session.inner_voice_enabled = False
            db.add(session)
            db.flush()

    config = ChatConfigOut(
        session_id=session.id if session else None,
        chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
        chat_used_today=user_row.chat_used_today or 0,
        inner_voice_enabled=bool(inner_voice),
        is_vip=_is_vip(user_row),
    )
    return success_response(config)