1295 lines
44 KiB
Python
1295 lines
44 KiB
Python
from datetime import datetime, date
|
||
from typing import List, Optional
|
||
import re
|
||
import random
|
||
|
||
import oss2
|
||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import desc
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy.exc import IntegrityError
|
||
|
||
from ..config import settings
|
||
from ..db import get_db
|
||
from ..deps import AuthedUser, get_current_user
|
||
from ..llm import chat_completion
|
||
from ..models import ChatMessage, ChatSession, Lover, User, ChatFact, ChatSummary, VoiceLibrary # type: ignore
|
||
from ..response import ApiResponse, success_response
|
||
from ..tts import synthesize
|
||
from ..vision import describe_image
|
||
from .config import _parse_owned_voices
|
||
|
||
# Router for every chat endpoint in this module; mounted under the "/chat"
# prefix and grouped under the "chat" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/chat", tags=["chat"])
|
||
|
||
|
||
class ChatSendIn(BaseModel):
    # Request body for POST /chat/send (one plain-text user message).

    # Target session; required. The handler answers 404 for sessions the user
    # does not own.
    session_id: int = Field(..., description="会话ID,必填")
    # Message text; pydantic enforces a length of 1..4000 characters.
    message: str = Field(..., min_length=1, max_length=4000)
|
||
|
||
|
||
class ChatSendOut(BaseModel):
    # Response payload for POST /chat/send.

    session_id: int
    # IDs of the stored user message and of the generated lover reply.
    user_message_id: int
    lover_message_id: int
    # Reply text returned to the client.
    reply: str
    # Effective inner-voice flag for this turn (the handler forces it off for
    # non-VIP users).
    inner_voice_enabled: bool
    # Token usage reported by the LLM client, if any.
    usage: Optional[dict] = None
    # Daily quota snapshot for the user.
    chat_limit_daily: int
    chat_used_today: int
|
||
|
||
|
||
class ChatSendImageOut(BaseModel):
    # Response payload for POST /chat/send-image; same shape as ChatSendOut.

    session_id: int
    # IDs of the stored image message and of the generated lover reply.
    user_message_id: int
    lover_message_id: int
    # Reply text, identical to the stored lover message content.
    reply: str
    # Effective inner-voice flag for this turn (forced off for non-VIP users).
    inner_voice_enabled: bool
    # Token usage reported by the LLM client, if any.
    usage: Optional[dict] = None
    # Daily quota snapshot taken after this image was counted.
    chat_limit_daily: int
    chat_used_today: int
|
||
|
||
|
||
|
||
|
||
class ChatConfigOut(BaseModel):
    # Chat configuration snapshot for the client. Not referenced by the
    # handlers visible in this section — NOTE(review): confirm where it is used.

    # Optional; presumably absent when no session exists yet — TODO confirm.
    session_id: Optional[int] = None
    # Daily message quota and how much of it is used today.
    chat_limit_daily: int
    chat_used_today: int
    # Whether the lover's inner-voice asides are enabled.
    inner_voice_enabled: bool
    is_vip: bool
|
||
|
||
|
||
class MessageOut(BaseModel):
    # A single chat message as returned to the client.

    id: int
    # Author role as stored; this module writes "user" and "lover".
    role: str
    # Message text; for image messages this module stores the image URL here.
    content: str
    # Per-session ordering number (0 when missing in the row).
    seq: int
    created_at: datetime
    # "text" or "image" for the messages this module creates.
    content_type: Optional[str] = None
    # Free-form metadata, e.g. {"image_caption": ..., "image_error": ...} for images.
    extra: Optional[dict] = None
    # Reserved so the frontend can show whether TTS audio already exists and
    # avoid requesting it again.
    tts_url: Optional[str] = None
    # TTS state; this module writes "pending", "succeeded" or "failed".
    tts_status: Optional[str] = None
    tts_voice_id: Optional[int] = None
    tts_model_id: Optional[str] = None
|
||
|
||
|
||
class TTSOut(BaseModel):
    # Response payload for POST /chat/messages/tts/{message_id}.

    message_id: int
    # Public URL of the uploaded audio.
    tts_url: Optional[str] = None
    # "succeeded" in the successful responses this module builds.
    tts_status: str
    # Voice-library id of the voice used; None when the configured fallback
    # voice produced the audio.
    tts_voice_id: Optional[int] = None
    tts_model_id: Optional[str] = None
|
||
|
||
|
||
class SessionInitOut(BaseModel):
    # Response payload for GET /chat/session/init.

    session_id: int
    # One page of the newest messages, in ascending seq order.
    messages: List[MessageOut]
    # Daily quota snapshot.
    chat_limit_daily: int
    chat_used_today: int
    # Effective inner-voice flag (forced off for non-VIP users).
    inner_voice_enabled: bool
    is_vip: bool
|
||
|
||
|
||
class MessagesPageOut(BaseModel):
    # Response payload for GET /chat/messages (history paging).

    session_id: int
    # 1-based page number and page size as requested.
    page: int
    size: int
    # Messages of this page, in ascending seq order.
    messages: List[MessageOut]
    # True while older messages exist beyond this page.
    has_more: bool
|
||
|
||
|
||
def _reset_chat_quota(user_row: User):
    """Zero the daily chat counter when the stored reset date is not today.

    ``user_row`` is mutated in place; nothing is flushed or committed here.
    """
    current_day = date.today()
    if user_row.chat_reset_date == current_day:
        return
    user_row.chat_reset_date = current_day
    user_row.chat_used_today = 0
|
||
|
||
|
||
def _ensure_quota(user_row: User):
    """Raise if the user has used up today's chat quota.

    The daily counter is reset first when the day has rolled over. The
    per-user limit falls back to ``settings.CHAT_LIMIT_DAILY`` when unset or
    zero. A NULL ``chat_used_today`` counts as 0, which matches how the rest
    of this module reads the counter.

    Raises:
        HTTPException: 400 when the quota is exhausted.
    """
    _reset_chat_quota(user_row)
    limit = user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY
    # Without the ``or 0`` a NULL counter would raise TypeError on comparison.
    used = user_row.chat_used_today or 0
    if used >= limit:
        raise HTTPException(status_code=400, detail="今日聊天次数已用完")
|
||
|
||
|
||
def _pick_llm_seed() -> int:
|
||
return random.randint(0, 2**31 - 1)
|
||
|
||
|
||
def _normalize_text(text: Optional[str]) -> str:
|
||
if not text:
|
||
return ""
|
||
cleaned = re.sub(r"\s+", " ", text)
|
||
return cleaned.strip()
|
||
|
||
|
||
_CONTEXT_MSG_OVERHEAD_TOKENS = 4
|
||
|
||
|
||
def _estimate_tokens_for_context(text: Optional[str]) -> int:
|
||
if not text:
|
||
return 0
|
||
ascii_chars = 0
|
||
for ch in text:
|
||
if ord(ch) < 128:
|
||
ascii_chars += 1
|
||
non_ascii_chars = len(text) - ascii_chars
|
||
tokens = non_ascii_chars + (ascii_chars + 3) // 4
|
||
return max(1, tokens)
|
||
|
||
|
||
def _estimate_message_tokens(msg: dict) -> int:
|
||
content = msg.get("content") if isinstance(msg, dict) else ""
|
||
return _estimate_tokens_for_context(content or "") + _CONTEXT_MSG_OVERHEAD_TOKENS
|
||
|
||
|
||
def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
|
||
if not text or max_tokens <= 0:
|
||
return ""
|
||
if _estimate_tokens_for_context(text) <= max_tokens:
|
||
return text
|
||
approx_tokens = _estimate_tokens_for_context(text)
|
||
ratio = max_tokens / max(1, approx_tokens)
|
||
new_len = max(1, int(len(text) * ratio))
|
||
trimmed = text[:new_len]
|
||
for _ in range(6):
|
||
if _estimate_tokens_for_context(trimmed) <= max_tokens:
|
||
break
|
||
new_len = max(1, int(new_len * 0.85))
|
||
trimmed = text[:new_len]
|
||
return trimmed
|
||
|
||
|
||
def _truncate_message_to_budget(msg: dict, max_tokens: int) -> dict:
|
||
content = msg.get("content") if isinstance(msg, dict) else ""
|
||
content_budget = max_tokens - _CONTEXT_MSG_OVERHEAD_TOKENS
|
||
trimmed_content = _truncate_text_to_tokens(content or "", max(0, content_budget))
|
||
new_msg = dict(msg)
|
||
new_msg["content"] = trimmed_content
|
||
return new_msg
|
||
|
||
|
||
def _trim_context_messages(messages: List[dict]) -> List[dict]:
    """Fit a prompt into ``settings.CHAT_CONTEXT_MAX_TOKENS`` estimated tokens.

    ``messages`` is split into its leading run of system messages and the
    dialogue after it. While over budget, these steps are applied in order:

    1. drop trailing system messages, always keeping the first (base) one;
    2. drop the oldest dialogue messages, always keeping the newest one;
    3. truncate the oldest remaining dialogue message to the leftover budget;
    4. as a last resort, truncate the base system prompt.

    A non-positive budget disables trimming. The input list is returned
    unchanged when it already fits.
    """
    budget = settings.CHAT_CONTEXT_MAX_TOKENS or 0
    if budget <= 0:
        return messages

    split_at = len(messages)
    for idx, item in enumerate(messages):
        if item.get("role") != "system":
            split_at = idx
            break
    system_part = list(messages[:split_at])
    dialogue = list(messages[split_at:])

    def _sum_tokens(items: List[dict]) -> int:
        return sum(_estimate_message_tokens(item) for item in items)

    running = _sum_tokens(system_part) + _sum_tokens(dialogue)
    if running <= budget:
        return messages

    # Step 1: peel optional system messages (summary, facts, ...) off the end.
    while running > budget and len(system_part) > 1:
        running -= _estimate_message_tokens(system_part.pop())

    # Step 2: forget the oldest turns, but never the latest message.
    dropped = 0
    while running > budget and dropped < len(dialogue) - 1:
        running -= _estimate_message_tokens(dialogue[dropped])
        dropped += 1
    dialogue = dialogue[dropped:]

    # Step 3: shrink the oldest surviving turn to whatever budget remains.
    if dialogue and running > budget:
        room = budget - _sum_tokens(system_part) - _sum_tokens(dialogue[1:])
        dialogue[0] = _truncate_message_to_budget(dialogue[0], room)

    # Step 4: the base system prompt alone may still be too large.
    if system_part and _sum_tokens(system_part) + _sum_tokens(dialogue) > budget:
        room = budget - _sum_tokens(dialogue)
        system_part[0] = _truncate_message_to_budget(system_part[0], room)

    return system_part + dialogue
|
||
|
||
|
||
def _is_exact_repeat(last_msg: Optional[ChatMessage], reply_text: str) -> bool:
    """Report whether ``reply_text`` repeats the previous lover message verbatim.

    Only a lover-authored ``last_msg`` can be repeated. Both texts are
    whitespace-normalized before they are compared.
    """
    previous_is_lover = bool(last_msg) and last_msg.role == "lover"
    if not previous_is_lover:
        return False
    return _normalize_text(reply_text) == _normalize_text(last_msg.content or "")
|
||
|
||
|
||
def _fetch_session(db: Session, user_id: int, session_id: Optional[int]) -> Optional[ChatSession]:
    """Load the chat session ``session_id`` if it belongs to ``user_id``.

    Returns ``None`` when ``session_id`` is falsy (``None`` or 0).

    Raises:
        HTTPException: 404 when no such session is owned by the user.
    """
    if session_id:
        found = (
            db.query(ChatSession)
            .filter(ChatSession.id == session_id, ChatSession.user_id == user_id)
            .first()
        )
        if not found:
            raise HTTPException(status_code=404, detail="会话不存在")
        return found
    return None
|
||
|
||
|
||
def _create_session(db: Session, user: AuthedUser, lover: Lover, inner_voice_enabled: bool) -> ChatSession:
    """Insert a new "active" chat session for ``user`` and ``lover``.

    All timestamps share one naive UTC instant. The row is flushed right away,
    so its id is available and any constraint violation (such as the
    ``IntegrityError`` handled by ``_get_or_create_primary_session``) is
    raised here.
    """
    timestamp = datetime.utcnow()
    new_session = ChatSession(
        user_id=user.id,
        lover_id=lover.id,
        model=settings.LLM_MODEL or "qwen-flash",
        status="active",
        inner_voice_enabled=inner_voice_enabled,
        created_at=timestamp,
        updated_at=timestamp,
        last_message_at=timestamp,
    )
    db.add(new_session)
    db.flush()
    return new_session
|
||
|
||
|
||
def _build_system_prompt(lover: Lover, user: AuthedUser, inner_voice_enabled: bool) -> str:
    """Compose the base system prompt, one instruction per line.

    The prompt contains the lover persona addressed to the user's nickname,
    the content restrictions, the lover's personality prompt (when set), an
    anti-repetition rule, and an inner-voice rule whose wording depends on
    ``inner_voice_enabled``.
    """
    display_name = user.nickname or "用户"
    lines = [
        f"你是用户 {display_name} 的虚拟恋人,请用亲密、温暖的语气聊天。",
        "禁止涉政、违法、暴力、未成年相关内容。",
    ]
    if lover.personality_prompt:
        lines.append(f"人格设定:{lover.personality_prompt}")
    lines.append("避免与上一条回复完全一致;若用户要求复述,请换一种说法简要转述。")
    lines.append(
        "在适当时机用全角括号(…)加入心理活动,正文与心声并存。"
        if inner_voice_enabled
        else "不要输出括号(…)内的心理活动。"
    )
    return "\n".join(lines)
|
||
|
||
|
||
def _build_context_messages(
    db: Session,
    session: ChatSession,
    lover: Lover,
    user: AuthedUser,
    inner_voice_enabled: bool,
) -> List[dict]:
    """Assemble the LLM prompt for ``session`` and trim it to the token budget.

    Order of the messages:

    1. the base system prompt;
    2. the newest rolling summary of the session, if any;
    3. up to ``settings.CHAT_FACT_MAX_ROWS`` profile facts for this user/lover;
    4. the latest ``settings.CHAT_RECENT_WINDOW`` messages in chronological order.

    Image messages are replaced by a description built from their stored
    caption. Edited messages carry a note with their original content. The
    result goes through ``_trim_context_messages``.
    """
    messages: List[dict] = []

    # Base system prompt (persona, restrictions, inner-voice rule).
    messages.append({"role": "system", "content": _build_system_prompt(lover, user, inner_voice_enabled)})

    # Newest rolling summary of this session.
    summary = (
        db.query(ChatSummary)
        .filter(ChatSummary.session_id == session.id)
        .order_by(ChatSummary.upto_seq.desc())
        .first()
    )
    if summary:
        messages.append({"role": "system", "content": f"对话摘要(累积至 seq {summary.upto_seq}):{summary.summary_text}"})

    # Profile facts: highest weight first, then newest first.
    facts = (
        db.query(ChatFact)
        .filter(ChatFact.user_id == user.id, ChatFact.lover_id == lover.id)
        .order_by(desc(ChatFact.weight), ChatFact.created_at.desc())
        .limit(settings.CHAT_FACT_MAX_ROWS)
        .all()
    )
    if facts:
        fact_lines = [f"- {f.kind or 'fact'}: {f.content}" for f in facts]
        messages.append({"role": "system", "content": "画像与重要事实:\n" + "\n".join(fact_lines)})

    # Recent window (including the newest user message): fetched newest-first,
    # then replayed in chronological order.
    history = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc())
        .limit(settings.CHAT_RECENT_WINDOW)
        .all()
    )
    user_nickname = user.nickname or "用户"
    for msg in reversed(history):
        role = "assistant" if msg.role == "lover" else "user" if msg.role == "user" else "system"

        if msg.content_type == "image":
            # Image messages store the URL as content; describe them through
            # the caption saved in ``extra`` instead. Non-dict extras are ignored.
            extra = msg.extra if isinstance(msg.extra, dict) else {}
            caption = extra.get("image_caption")
            if msg.role == "user":
                content = f"{user_nickname}给你发了一张图片,内容是:{caption or '图片已收到,但未能识别'}"
            else:
                content = f"恋人发送了一张图片,内容是:{caption or '图片'}"
        else:
            # A NULL content column would otherwise reach the LLM as None.
            content = msg.content or ""

        # Flag edited messages so the model treats the corrected text as authoritative.
        if msg.is_edited and msg.original_content:
            if msg.role == "lover":
                # The user corrected one of the lover's replies.
                content = f"[用户纠正] {content}\n(原回复:{msg.original_content})\n注意:用户认为原回复不准确,已纠正为上述内容,请以用户纠正的内容为准。"
            else:
                # The user edited their own message.
                content = f"{content}\n[已编辑,原内容:{msg.original_content}]"

        messages.append({"role": role, "content": content})

    return _trim_context_messages(messages)
|
||
|
||
|
||
def _is_vip(user_row: User) -> bool:
    """Return True while the user's VIP expiry (Unix epoch seconds) is in the future.

    "Now" comes from ``time.time()``. ``datetime.utcnow().timestamp()`` would
    interpret the naive UTC datetime as *local* time and shift "now" by the
    server's UTC offset (e.g. 8 hours on a UTC+8 host). Missing or unparsable
    expiry values mean "not VIP".
    """
    import time

    try:
        end_time = user_row.vip_endtime
        return bool(end_time) and int(end_time) > int(time.time())
    except (AttributeError, TypeError, ValueError, OverflowError):
        # Best effort: malformed data must never grant or crash VIP checks.
        return False
|
||
|
||
|
||
def _get_or_create_primary_session(db: Session, user: AuthedUser, lover: Lover) -> ChatSession:
    """Return the single active session for (user, lover), creating it when absent.

    - When active sessions exist (locked FOR UPDATE), the most recently
      created one is returned and every other one is marked "archived".
    - Otherwise a new session is created. Its inner-voice flag is the user's
      stored preference or ``settings.INNER_VOICE_DEFAULT``, and the lover's
      opening line is stored as message seq 1.
    - When creation raises ``IntegrityError`` (concurrent creation, or an
      archived session hitting a unique key), the transaction is rolled back
      and the newest existing session for the pair is reactivated and
      returned, without a new opening message. If none exists, the error
      is re-raised.
    """
    candidates = (
        db.query(ChatSession)
        .filter(
            ChatSession.user_id == user.id,
            ChatSession.lover_id == lover.id,
            ChatSession.status == "active",
        )
        .with_for_update()
        .order_by(ChatSession.created_at.desc())
        .all()
    )
    if candidates:
        primary, *duplicates = candidates
        for stale in duplicates:
            stale.status = "archived"
            db.add(stale)
        db.flush()
        return primary

    # New session: inner-voice flag from the user's default, else the global default.
    owner = db.query(User).filter(User.id == user.id).first()
    if owner and owner.inner_voice_enabled is not None:
        inner_voice = owner.inner_voice_enabled
    else:
        inner_voice = settings.INNER_VOICE_DEFAULT

    try:
        session = _create_session(db, user, lover, inner_voice_enabled=inner_voice)
    except IntegrityError:
        # Lost a creation race or collided with an archived row: reuse it.
        db.rollback()
        fallback = (
            db.query(ChatSession)
            .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id)
            .with_for_update()
            .order_by(ChatSession.created_at.desc())
            .first()
        )
        if not fallback:
            raise
        if fallback.status != "active":
            fallback.status = "active"
            fallback.updated_at = datetime.utcnow()
            db.add(fallback)
            db.flush()
        return fallback

    # Seed the conversation with the lover's opening line as message #1.
    created_at = datetime.utcnow()
    db.add(
        ChatMessage(
            session_id=session.id,
            user_id=user.id,
            lover_id=lover.id,
            role="lover",
            content_type="text",
            content=lover.opening_line or "嗨,我是你的恋人,很高兴见到你~",
            seq=1,
            created_at=created_at,
            model=settings.LLM_MODEL or "qwen-flash",
        )
    )
    session.last_message_at = created_at
    db.add(session)
    db.flush()
    return session
|
||
|
||
|
||
def _strip_inner_voice_text(raw: Optional[str]) -> str:
|
||
"""
|
||
去除括号中的内心/旁白内容,便于生成 TTS。
|
||
"""
|
||
if not raw:
|
||
return ""
|
||
# 去掉全角括号内的内容(常用于心声/旁白)
|
||
cleaned = re.sub(r"(.*?)", "", raw)
|
||
cleaned = re.sub(r"\s+", " ", cleaned)
|
||
return cleaned.strip()
|
||
|
||
|
||
def _pick_available_voice(db: Session, lover: Lover, user_row: User) -> VoiceLibrary:
    """Choose the TTS voice for ``lover``, respecting voice ownership.

    Only voices matching ``lover.gender`` are considered, in this order:

    1. the lover's configured ``voice_id``, which must be free
       (``price_gold`` <= 0) or owned by the user;
    2. a voice flagged ``is_default`` for that gender (defaults are usable
       directly);
    3. the lowest-id voice of that gender that is free or owned by the user.

    Raises:
        HTTPException: 404 when no usable voice exists; 500 when the chosen
            voice has no ``voice_code``.
    """
    owned_ids = _parse_owned_voices(user_row.owned_voice_ids)

    def _allowed(voice: VoiceLibrary) -> bool:
        # Paid voices require ownership; free voices are open to everyone.
        return (voice.price_gold or 0) <= 0 or int(voice.id) in owned_ids

    candidate: Optional[VoiceLibrary] = None

    if lover.voice_id:
        configured = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.id == lover.voice_id, VoiceLibrary.gender == lover.gender)
            .first()
        )
        if configured and _allowed(configured):
            candidate = configured

    if not candidate:
        candidate = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.gender == lover.gender, VoiceLibrary.is_default.is_(True))
            .first()
        )

    if not candidate:
        # Last resort: never hand out a paid voice the user does not own.
        same_gender = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.gender == lover.gender)
            .order_by(VoiceLibrary.id.asc())
            .all()
        )
        candidate = next((v for v in same_gender if _allowed(v)), None)

    if not candidate:
        raise HTTPException(status_code=404, detail="未找到可用音色")
    if not candidate.voice_code:
        raise HTTPException(status_code=500, detail="音色未配置 voice_code")
    return candidate
|
||
|
||
|
||
def _upload_tts_to_oss(file_bytes: bytes, lover_id: int, message_id: int) -> str:
    """Upload synthesized audio to Aliyun OSS and return its public URL.

    The object key is ``lover/{lover_id}/tts/{message_id}.mp3``. The URL uses
    ``ALIYUN_OSS_CDN_DOMAIN`` when configured. Otherwise it is the https
    bucket-host URL derived from the endpoint.

    Raises:
        HTTPException: 500 when OSS credentials or bucket/endpoint are not
            configured; 502 when the upload fails.
    """
    if not (settings.ALIYUN_OSS_ACCESS_KEY_ID and settings.ALIYUN_OSS_ACCESS_KEY_SECRET):
        raise HTTPException(status_code=500, detail="未配置 OSS Key")
    if not (settings.ALIYUN_OSS_BUCKET_NAME and settings.ALIYUN_OSS_ENDPOINT):
        raise HTTPException(status_code=500, detail="未配置 OSS Bucket/Endpoint")

    key = f"lover/{lover_id}/tts/{message_id}.mp3"
    endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/")
    bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
    try:
        credentials = oss2.Auth(settings.ALIYUN_OSS_ACCESS_KEY_ID, settings.ALIYUN_OSS_ACCESS_KEY_SECRET)
        oss2.Bucket(credentials, endpoint, bucket_name).put_object(key, file_bytes)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"上传语音失败: {exc}") from exc

    cdn_domain = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn_domain:
        return f"{cdn_domain.rstrip('/')}/{key}"
    bare_host = endpoint.replace("https://", "").replace("http://", "")
    return f"https://{bucket_name}.{bare_host}/{key}"
|
||
|
||
|
||
@router.get("/session/init", response_model=ApiResponse[SessionInitOut])
def init_session(
    page: int = Query(1, ge=1),
    size: int = Query(15, ge=1, le=100),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Open the user's primary chat session and return its latest messages.

    The session is fetched or created with ``_get_or_create_primary_session``.
    ``page``/``size`` select a window counted back from the newest message,
    and the window is returned in ascending ``seq`` order. The response also
    carries:

    - the daily quota snapshot (reset first if the day rolled over);
    - the effective inner-voice flag;
    - the VIP status.

    The inner-voice flag is resolved as session setting, then user
    preference, then global default. It is forced off, and persisted on the
    session, for non-VIP users.

    Raises:
        HTTPException: 404 when the user has no lover or the user row is missing.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")

    session = _get_or_create_primary_session(db, user, lover)

    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    _reset_chat_quota(user_row)

    # Newest-first window from the DB, flipped to chronological order below.
    # No COUNT query is needed: SessionInitOut exposes no has_more flag.
    records = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc())
        .offset((page - 1) * size)
        .limit(size)
        .all()
    )
    records = sorted(records, key=lambda m: m.seq or 0)

    def to_message(msg: ChatMessage) -> MessageOut:
        return MessageOut(
            id=msg.id,
            role=msg.role,
            content=msg.content,
            seq=msg.seq or 0,
            created_at=msg.created_at,
            content_type=getattr(msg, "content_type", None),
            extra=getattr(msg, "extra", None),
            tts_url=getattr(msg, "tts_url", None),
            tts_status=getattr(msg, "tts_status", None),
            tts_voice_id=getattr(msg, "tts_voice_id", None),
            tts_model_id=getattr(msg, "tts_model_id", None),
        )

    # Evaluate VIP once so the inner-voice decision and the response agree.
    is_vip = _is_vip(user_row)
    if session.inner_voice_enabled is not None:
        inner_voice = session.inner_voice_enabled
    elif user_row.inner_voice_enabled is not None:
        inner_voice = user_row.inner_voice_enabled
    else:
        inner_voice = settings.INNER_VOICE_DEFAULT

    # Inner voice is VIP-only: switch it off persistently once VIP has lapsed.
    if inner_voice and not is_vip:
        inner_voice = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()

    return success_response(
        SessionInitOut(
            session_id=session.id,
            messages=[to_message(m) for m in records],
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
            inner_voice_enabled=bool(inner_voice),
            is_vip=is_vip,
        )
    )
|
||
|
||
|
||
@router.get("/messages", response_model=ApiResponse[MessagesPageOut])
def list_messages(
    session_id: int,
    page: int = Query(1, ge=1),
    size: int = Query(15, ge=1, le=100),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Page backwards through a session's history.

    Page 1 holds the ``size`` newest messages, and later pages go further
    back. Each page is returned in ascending ``seq`` order. Paging uses the
    same ``seq`` ordering as ``init_session``, with ``id`` as a tie-breaker,
    so pages never overlap or skip messages even if ``id`` and ``seq``
    orders diverge.

    Raises:
        HTTPException: 404 when the session does not exist or is not the user's.
    """
    session = _fetch_session(db, user.id, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    query = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc(), ChatMessage.id.desc())
    )
    total = query.count()
    records = query.offset((page - 1) * size).limit(size).all()
    # Within a page, return messages in chronological (ascending seq) order.
    records = sorted(records, key=lambda m: m.seq or 0)

    def to_message(msg: ChatMessage) -> MessageOut:
        return MessageOut(
            id=msg.id,
            role=msg.role,
            content=msg.content,
            seq=msg.seq or 0,
            created_at=msg.created_at,
            content_type=getattr(msg, "content_type", None),
            extra=getattr(msg, "extra", None),
            tts_url=getattr(msg, "tts_url", None),
            tts_status=getattr(msg, "tts_status", None),
            tts_voice_id=getattr(msg, "tts_voice_id", None),
            tts_model_id=getattr(msg, "tts_model_id", None),
        )

    has_more = page * size < total
    return success_response(
        MessagesPageOut(
            session_id=session.id,
            page=page,
            size=size,
            messages=[to_message(m) for m in records],
            has_more=has_more,
        )
    )
|
||
|
||
|
||
@router.post("/messages/tts/{message_id}", response_model=ApiResponse[TTSOut])
def generate_message_tts(
    message_id: int,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Synthesize, or return the cached, TTS audio for one lover text message.

    Flow:

    1. Lock the message row and validate it: it must be the user's lover text
       message, non-empty after stripping inner-voice segments, and at most
       80 characters.
    2. Return the cached URL if a previous run succeeded.
    3. Otherwise mark the row "pending", synthesize with the lover's voice and
       upload to OSS.
    4. If that fails with a model/voice "not found" error, retry once with the
       configured fallback voice/model (``VOICE_CALL_TTS_*``).
    5. On failure, mark the row "failed" with a truncated error text.

    The stored ``tts_voice_id``/``tts_model_id`` always describe what actually
    produced the audio. The fallback voice is not a library voice, so its
    ``tts_voice_id`` is ``None``.

    Raises:
        HTTPException: 404 (message, lover or user missing), 400 (unsupported
            or unsuitable message), 502 (unexpected synthesis failure), or
            the synthesis/upload ``HTTPException`` when no fallback succeeds.
    """
    msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.id == message_id, ChatMessage.user_id == user.id)
        .with_for_update()
        .first()
    )
    if not msg:
        raise HTTPException(status_code=404, detail="消息不存在")
    if msg.role != "lover" or msg.content_type != "text":
        raise HTTPException(status_code=400, detail="仅支持恋人文本消息生成语音")

    lover = db.query(Lover).filter(Lover.id == msg.lover_id, Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人不存在")
    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    clean_text = _strip_inner_voice_text(msg.content or "")
    if not clean_text:
        raise HTTPException(status_code=400, detail="消息内容为空,无法生成语音")
    if len(clean_text) > 80:
        raise HTTPException(status_code=400, detail="文本长度超过80字符,不支持生成语音")

    # Already generated: serve the cached result.
    if msg.tts_url and msg.tts_status == "succeeded":
        return success_response(
            TTSOut(
                message_id=msg.id,
                tts_url=msg.tts_url,
                tts_status=msg.tts_status,
                tts_voice_id=getattr(msg, "tts_voice_id", None),
                tts_model_id=getattr(msg, "tts_model_id", None),
            )
        )

    voice = _pick_available_voice(db, lover, user_row)
    model = voice.tts_model_id or "cosyvoice-v2"

    # Mark pending first (the row is locked FOR UPDATE) to avoid duplicate work.
    msg.tts_status = "pending"
    msg.tts_error = None
    db.add(msg)
    db.flush()

    def _mark_failed(error_text: str) -> None:
        msg.tts_status = "failed"
        msg.tts_error = error_text[:255]
        db.add(msg)
        db.flush()

    # Track what actually produced the audio; overwritten if the fallback is used.
    used_voice_id: Optional[int] = voice.id
    used_model = model
    try:
        audio_bytes, fmt_name = synthesize(clean_text, model=model, voice=voice.voice_code)
        url = _upload_tts_to_oss(audio_bytes, lover.id, msg.id)
    except HTTPException as exc:
        detail_text = str(exc.detail)
        url = ""
        if "ModelNotFound" in detail_text or "not found" in detail_text:
            fallback_voice = settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
            fallback_model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
            if fallback_voice != voice.voice_code or fallback_model != model:
                try:
                    audio_bytes, fmt_name = synthesize(
                        clean_text,
                        model=fallback_model,
                        voice=fallback_voice,
                    )
                    url = _upload_tts_to_oss(audio_bytes, lover.id, msg.id)
                except Exception:
                    # Fallback is best effort; the original error is reported below.
                    url = ""
                else:
                    used_voice_id = None  # the fallback voice has no library id
                    used_model = fallback_model
        if not url:
            _mark_failed(detail_text)
            raise
    except Exception as exc:
        _mark_failed(str(exc))
        raise HTTPException(status_code=502, detail="语音生成失败,请稍后重试") from exc

    msg.tts_url = url
    msg.tts_status = "succeeded"
    msg.tts_voice_id = used_voice_id
    msg.tts_model_id = used_model
    msg.tts_format = fmt_name
    msg.tts_error = None
    db.add(msg)
    db.flush()

    return success_response(
        TTSOut(
            message_id=msg.id,
            tts_url=msg.tts_url,
            tts_status=msg.tts_status,
            tts_voice_id=msg.tts_voice_id,
            tts_model_id=msg.tts_model_id,
        )
    )
|
||
|
||
|
||
class ChatSendImageIn(BaseModel):
    # Request body for POST /chat/send-image.

    # Target session owned by the caller.
    session_id: int
    # Publicly reachable image URL; the handler rejects non-http(s) values.
    image_url: str
|
||
|
||
|
||
@router.post("/send-image", response_model=ApiResponse[ChatSendImageOut])
def send_image_message(
    payload: ChatSendImageIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Send an image to the lover (VIP only) and return the lover's text reply.

    Steps:

    1. Validate the image URL, load the lover and lock the user row.
    2. Require VIP and remaining daily quota (images count against it).
    3. Require the caller's session, and that the newest message is not an
       unanswered user message.
    4. Store the image message and caption it with the vision model. Caption
       failures are recorded in ``extra`` and never abort the chat.
    5. Ask the LLM for a reply, retrying once if it exactly repeats the
       previous lover message.
    6. Store the reply, bump the quota and possibly roll the session summary.

    Raises:
        HTTPException: 400 (bad URL, quota used up, reply still pending),
            403 (not VIP), 404 (lover, user or session missing).
    """
    img_url = (payload.image_url or "").strip()
    # Require an absolute http(s) URL; a bare "http" prefix check would also
    # accept strings such as "httpfoo".
    if not img_url.startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="图片URL格式不正确")

    # A lover must already exist.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")

    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Sending images is a VIP feature.
    if not _is_vip(user_row):
        raise HTTPException(status_code=403, detail="仅 VIP 用户可发送图片")
    _ensure_quota(user_row)  # images count against the daily quota too

    # The caller must pass an existing session.
    session = _fetch_session(db, user.id, payload.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    # No back-to-back sends: if the newest message is the user's, wait for a reply.
    last_msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .with_for_update()
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    if last_msg and last_msg.role == "user":
        raise HTTPException(status_code=400, detail="请等待恋人回复后再发送")

    next_seq = (last_msg.seq if last_msg and last_msg.seq else 0) + 1
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="image",
        content=img_url,
        seq=next_seq,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        extra={},
    )
    db.add(user_msg)
    db.flush()

    # Caption the image; any failure is recorded but must not block the chat.
    caption = "图片已收到"
    image_error: Optional[str] = None
    try:
        caption = describe_image(img_url)
    except HTTPException as exc:
        image_error = str(exc.detail)
    except Exception as exc:
        image_error = str(exc) or type(exc).__name__
    if image_error is None:
        user_msg.extra = {"image_caption": caption}
    else:
        user_msg.extra = {"image_caption": caption, "image_error": image_error}
    db.add(user_msg)
    db.flush()

    inner_voice_enabled = session.inner_voice_enabled if session.inner_voice_enabled is not None else settings.INNER_VOICE_DEFAULT
    # Inner voice is VIP-only; clear any stale flag left on the session.
    if inner_voice_enabled and not _is_vip(user_row):
        inner_voice_enabled = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()

    context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)
    max_chars = settings.CHAT_REPLY_MAX_CHARS or 0

    def _cap(text: str) -> str:
        # Enforce the configured reply length limit (0 disables it).
        if max_chars > 0 and len(text) > max_chars:
            return text[:max_chars]
        return text

    def _postprocess(raw_text: Optional[str]) -> str:
        # Drop bracketed inner-voice asides when that feature is off, then cap.
        text = raw_text or ""
        if not inner_voice_enabled:
            text = _strip_inner_voice_text(text)
        return _cap(text)

    def _complete(prompt: List[dict]):
        return chat_completion(
            prompt,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
            seed=_pick_llm_seed(),
        )

    llm_result = _complete(context_messages)
    reply_text = _postprocess(llm_result.content)

    if _is_exact_repeat(last_msg, reply_text):
        # One retry with an explicit "don't repeat yourself" instruction.
        retry_result = _complete(
            list(context_messages)
            + [
                {
                    "role": "system",
                    "content": "请避免与上一条恋人回复完全相同,换一种说法简要复述,保持亲密语气。",
                }
            ]
        )
        retry_text = _postprocess(retry_result.content)
        if retry_text and not _is_exact_repeat(last_msg, retry_text):
            llm_result = retry_result
            reply_text = retry_text
        else:
            base_text = last_msg.content or reply_text
            reply_text = _cap(f"我刚才说的是:{base_text}")

    usage = llm_result.usage or {}
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=reply_text,
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=usage.get("input_tokens"),
        token_output=usage.get("output_tokens"),
    )
    db.add(lover_msg)

    # Update session activity and consume one unit of the daily quota.
    session.last_message_at = datetime.utcnow()
    user_row.chat_used_today = (user_row.chat_used_today or 0) + 1
    db.add(session)
    db.add(user_row)
    db.flush()

    _maybe_generate_summary(db, session, upto_seq=lover_msg.seq)

    return success_response(
        ChatSendImageOut(
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            reply=reply_text,
            inner_voice_enabled=bool(inner_voice_enabled),
            usage=llm_result.usage,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
        )
    )
|
||
|
||
|
||
def _maybe_generate_summary(db: Session, session: ChatSession, upto_seq: int):
    """Write a new rolling summary into nf_chat_summary once enough new messages accumulate.

    Messages with ``last_upto < seq <= upto_seq`` are considered, where
    ``last_upto`` is the newest existing summary's ``upto_seq`` (or 0). A
    summary is generated when either configured trigger is reached:

    * ``CHAT_SUMMARY_TRIGGER_MSGS``: the number of new messages;
    * ``CHAT_SUMMARY_TRIGGER_TOKENS``: a rough token estimate of their text.

    A non-positive trigger is disabled; with both disabled nothing happens.

    Only the newest summary is fed back into the chat context, labelled as
    cumulative up to its seq. The previous summary is therefore folded into
    the prompt so the new one stays cumulative. Image messages contribute
    their stored caption rather than their URL. An empty LLM result is not
    stored, so it never shadows an earlier useful summary.
    """
    msg_trigger = settings.CHAT_SUMMARY_TRIGGER_MSGS or 0
    token_trigger = settings.CHAT_SUMMARY_TRIGGER_TOKENS or 0
    if msg_trigger <= 0 and token_trigger <= 0:
        return

    last_summary = (
        db.query(ChatSummary)
        .filter(ChatSummary.session_id == session.id)
        .order_by(ChatSummary.upto_seq.desc())
        .first()
    )
    last_upto = last_summary.upto_seq if last_summary else 0

    new_msgs = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id, ChatMessage.seq > last_upto, ChatMessage.seq <= upto_seq)
        .order_by(ChatMessage.seq.asc())
        .all()
    )
    if not new_msgs:
        return

    def estimate_tokens(text: str) -> int:
        # Deliberately coarse: about one token per two characters of mixed
        # Chinese/English text.
        if not text:
            return 0
        return max(1, len(text) // 2)

    def render(m: ChatMessage) -> str:
        # Image messages store the URL as content; summarize the caption instead.
        if m.content_type == "image":
            extra = m.extra if isinstance(m.extra, dict) else {}
            return f"[图片] {extra.get('image_caption') or '图片'}"
        return m.content or ""

    rendered = [(m, render(m)) for m in new_msgs]
    total_tokens = sum(estimate_tokens(text) for _, text in rendered)

    # Summarize when either enabled trigger is reached.
    msg_hit = msg_trigger > 0 and len(new_msgs) >= msg_trigger
    token_hit = token_trigger > 0 and total_tokens >= token_trigger
    if not (msg_hit or token_hit):
        return

    previous_text = last_summary.summary_text if last_summary else None
    lines = []
    if previous_text:
        lines.append(f"此前摘要:{previous_text}")
    for m, text in rendered:
        speaker = "用户" if m.role == "user" else "恋人" if m.role == "lover" else "系统"
        lines.append(f"{speaker}: {text}")

    system_prompt = "你是对话摘要助手,请用简洁中文总结关键事件、情绪、约定与待办,50~120字。"
    if previous_text:
        system_prompt += "若提供了此前摘要,请将其与新对话合并为一份完整的累积摘要。"
    prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "\n".join(lines)},
    ]
    result = chat_completion(prompt, temperature=0.3, max_tokens=300)
    summary_text = (result.content or "").strip()
    if not summary_text:
        # Keep the previous summary authoritative; retry on a later turn.
        return

    usage = result.usage or {}
    summary = ChatSummary(
        session_id=session.id,
        upto_seq=upto_seq,
        summary_text=summary_text,
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=usage.get("input_tokens"),
        token_output=usage.get("output_tokens"),
        created_at=datetime.utcnow(),
    )
    db.add(summary)
    db.flush()
|
||
|
||
|
||
def _postprocess_lover_reply(text: Optional[str], inner_voice_enabled: bool, max_chars: int) -> str:
    """Normalize a raw LLM reply before it is stored as a lover message.

    Args:
        text: Raw completion text (``None`` is treated as ``""``).
        inner_voice_enabled: When false, the inner-voice segment is removed
            via ``_strip_inner_voice_text``.
        max_chars: Hard character cap; ``0`` or less disables truncation.

    Returns:
        The cleaned, possibly truncated reply text.
    """
    result = text or ""
    if not inner_voice_enabled:
        result = _strip_inner_voice_text(result)
    if max_chars > 0 and len(result) > max_chars:
        result = result[:max_chars]
    return result


@router.post("/send", response_model=ApiResponse[ChatSendOut])
def send_message(
    payload: ChatSendIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Store a user text message in an existing session and generate the lover's reply.

    Steps: require a lover, lock the user row and check the daily quota, load
    the session, refuse to send while the latest message is still an
    unanswered user message, persist the user message, call the LLM (retrying
    once if the reply exactly repeats the previous lover reply), persist the
    reply, bump today's usage counter and possibly roll a summary.

    Raises:
        HTTPException(404): lover, user row or session not found.
        HTTPException(400): the previous user message has not been answered yet.
    """
    # A lover must exist before chatting.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")

    # SELECT ... FOR UPDATE on the user row: the quota is checked here and
    # incremented further down within the same transaction.
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    _ensure_quota(user_row)

    # Only an existing session is accepted; sessions are not created here.
    session = _fetch_session(db, user.id, payload.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    # No back-to-back sends: if the latest message is from the user, wait for the reply.
    last_msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .with_for_update()
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    if last_msg and last_msg.role == "user":
        raise HTTPException(status_code=400, detail="请等待恋人回复后再发送")

    next_seq = (last_msg.seq if last_msg and last_msg.seq else 0) + 1
    model_name = settings.LLM_MODEL or "qwen-flash"

    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="text",
        content=payload.message,
        seq=next_seq,
        created_at=datetime.utcnow(),
        model=model_name,
    )
    db.add(user_msg)
    db.flush()

    # Effective inner-voice flag. The precedence matches GET /chat/config:
    # the session switch, then the user's own setting, then the global default.
    if session.inner_voice_enabled is not None:
        inner_voice_enabled = session.inner_voice_enabled
    elif user_row.inner_voice_enabled is not None:
        inner_voice_enabled = user_row.inner_voice_enabled
    else:
        inner_voice_enabled = settings.INNER_VOICE_DEFAULT
    # Inner voice is VIP-only; clear a stale session switch left over from earlier.
    if inner_voice_enabled and not _is_vip(user_row):
        inner_voice_enabled = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()

    context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)
    # Reply length cap (0 means no truncation).
    max_chars = settings.CHAT_REPLY_MAX_CHARS or 0

    llm_result = chat_completion(
        context_messages,
        temperature=settings.LLM_TEMPERATURE,
        max_tokens=settings.LLM_MAX_TOKENS,
        seed=_pick_llm_seed(),
    )
    reply_text = _postprocess_lover_reply(llm_result.content, inner_voice_enabled, max_chars)

    if _is_exact_repeat(last_msg, reply_text):
        # Retry once with an explicit instruction not to repeat the previous lover reply.
        retry_messages = [
            *context_messages,
            {
                "role": "system",
                "content": "请避免与上一条恋人回复完全相同,换一种说法简要复述,保持亲密语气。",
            },
        ]
        retry_result = chat_completion(
            retry_messages,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
            seed=_pick_llm_seed(),
        )
        retry_text = _postprocess_lover_reply(retry_result.content, inner_voice_enabled, max_chars)
        if retry_text and not _is_exact_repeat(last_msg, retry_text):
            llm_result = retry_result
            reply_text = retry_text
        else:
            # Still a repeat (or empty): fall back to an explicit recap of the last reply.
            base_text = last_msg.content or reply_text
            reply_text = f"我刚才说的是:{base_text}"
            if max_chars > 0 and len(reply_text) > max_chars:
                reply_text = reply_text[:max_chars]

    usage = llm_result.usage or {}
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=reply_text,
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=model_name,
        token_input=usage.get("input_tokens"),
        token_output=usage.get("output_tokens"),
    )
    db.add(lover_msg)

    # Update session activity and consume one unit of today's quota.
    session.last_message_at = datetime.utcnow()
    user_row.chat_used_today = (user_row.chat_used_today or 0) + 1
    db.add(session)
    db.add(user_row)
    db.flush()

    # Possibly roll a new summary covering messages up to this reply.
    _maybe_generate_summary(db, session, upto_seq=lover_msg.seq)

    return success_response(
        ChatSendOut(
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            reply=reply_text,
            inner_voice_enabled=bool(inner_voice_enabled),
            usage=llm_result.usage,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
        )
    )
|
||
|
||
|
||
class InnerVoiceIn(BaseModel):
    """Request body for toggling inner voice on a chat session."""

    # ID of the session to update (looked up via _fetch_session for the caller).
    session_id: int
    # Desired state: True enables inner voice (VIP only), False disables it.
    enabled: bool
|
||
|
||
|
||
@router.put("/inner-voice", response_model=ApiResponse[ChatConfigOut])
def update_inner_voice(
    payload: InnerVoiceIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Set the inner-voice switch of one of the caller's sessions.

    Enabling requires VIP status; disabling is always allowed. The response
    carries the caller's current quota figures and VIP flag.

    Raises:
        HTTPException(404): lover, session or user row not found.
        HTTPException(403): a non-VIP user tries to enable inner voice.
    """
    session = _fetch_session(db, user.id, payload.session_id)
    # A lover must exist. A missing session is rejected, not created.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")

    user_row = db.query(User).filter(User.id == user.id).first()
    # Checked unconditionally: the response below reads user_row fields even
    # when the switch is being turned off.
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    # Only VIP users may turn inner voice on.
    if payload.enabled and not _is_vip(user_row):
        raise HTTPException(status_code=403, detail="仅 VIP 用户可开启心声")

    session.inner_voice_enabled = payload.enabled
    db.add(session)
    db.flush()

    _reset_chat_quota(user_row)  # refresh today count if needed

    return success_response(
        ChatConfigOut(
            session_id=session.id,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
            inner_voice_enabled=payload.enabled,
            is_vip=_is_vip(user_row),
        )
    )
|
||
|
||
|
||
@router.get("/config", response_model=ApiResponse[ChatConfigOut])
def get_chat_config(
    session_id: Optional[int] = Query(default=None),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Report the caller's chat quota, VIP status and effective inner-voice flag.

    The inner-voice flag is resolved in priority order:
    1. the session's own setting, when ``session_id`` is given and resolves to a session;
    2. the user's setting;
    3. ``settings.INNER_VOICE_DEFAULT``.

    Non-VIP users always get ``False``, and a stale ``True`` on the session is cleared.

    Raises:
        HTTPException(404): the user row does not exist.
    """
    session = None
    if session_id:
        session = _fetch_session(db, user.id, session_id)

    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Refresh today's usage counter if needed before reporting it.
    _reset_chat_quota(user_row)

    if session and session.inner_voice_enabled is not None:
        inner_voice = session.inner_voice_enabled
    elif user_row.inner_voice_enabled is not None:
        inner_voice = user_row.inner_voice_enabled
    else:
        inner_voice = settings.INNER_VOICE_DEFAULT

    # Inner voice is VIP-only; persist the downgrade on the session if one was given.
    if inner_voice and not _is_vip(user_row):
        inner_voice = False
        if session:
            session.inner_voice_enabled = False
            db.add(session)
            db.flush()

    config = ChatConfigOut(
        session_id=session.id if session else None,
        chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
        chat_used_today=user_row.chat_used_today or 0,
        inner_voice_enabled=bool(inner_voice),
        is_vip=_is_vip(user_row),
    )
    return success_response(config)
|
||
|
||
|
||
|
||
# ==================== 消息编辑功能 ====================
|
||
|
||
class EditMessageIn(BaseModel):
    """Request body for editing an existing chat message."""

    # Replacement text, 1..4000 characters (required).
    new_content: str = Field(
        default=...,
        min_length=1,
        max_length=4000,
        description="新的消息内容",
    )
    # Whether to regenerate the lover reply that follows an edited user message.
    regenerate_reply: bool = Field(True, description="是否重新生成AI回复")
|
||
|
||
|
||
class EditMessageOut(BaseModel):
    """Response body of the message-edit endpoint."""

    message_id: int  # ID of the edited message
    new_content: str  # content stored after the edit
    regenerated: bool  # True when the following lover reply was regenerated
    new_reply: Optional[str] = None  # regenerated reply text, when regenerated
|
||
|
||
|
||
@router.put("/messages/{message_id}/edit", response_model=ApiResponse[EditMessageOut])
def edit_message(
    message_id: int,
    payload: EditMessageIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Edit one of the caller's messages and optionally regenerate the AI reply.

    Both user and lover (AI) messages can be edited. On the first edit the
    pre-edit text is kept in ``original_content``. Summaries covering the
    edited message are deleted. When a user message is edited and
    ``regenerate_reply`` is set, the lover message directly after it
    (``seq + 1``) is regenerated with the LLM. That regeneration applies the
    same inner-voice VIP gate and reply-length cap as sending a message.

    Raises:
        HTTPException(404): the message does not exist or is not the caller's.
        HTTPException(400): the message role is neither "user" nor "lover".
    """
    # 1. Load and lock the message; it must belong to the caller.
    msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.id == message_id, ChatMessage.user_id == user.id)
        .with_for_update()
        .first()
    )
    if not msg:
        raise HTTPException(status_code=404, detail="消息不存在")

    # Editing AI (lover) messages is allowed as well.
    if msg.role not in ("user", "lover"):
        raise HTTPException(status_code=400, detail="只能编辑用户或AI消息")

    # 2. Keep the pre-edit text the first time the message is edited.
    if not msg.is_edited:
        msg.original_content = msg.content

    # 3. Apply the edit.
    msg.content = payload.new_content
    msg.is_edited = True
    msg.edited_at = datetime.utcnow()
    db.add(msg)
    db.flush()

    # 4. Delete every summary that covers this message; they must be regenerated.
    db.query(ChatSummary).filter(
        ChatSummary.session_id == msg.session_id,
        ChatSummary.upto_seq >= msg.seq,
    ).delete()
    db.flush()

    new_reply = None
    regenerated = False

    # 5. Regenerate the AI reply (only when a user message was edited).
    if payload.regenerate_reply and msg.role == "user":
        # The lover reply directly following the edited user message.
        next_lover_msg = (
            db.query(ChatMessage)
            .filter(
                ChatMessage.session_id == msg.session_id,
                ChatMessage.seq == msg.seq + 1,
                ChatMessage.role == "lover",
            )
            .with_for_update()
            .first()
        )

        if next_lover_msg:
            session = db.query(ChatSession).filter(ChatSession.id == msg.session_id).first()
            lover = db.query(Lover).filter(Lover.id == msg.lover_id).first()

            if session and lover:
                # Effective inner-voice flag: session switch, then user setting,
                # then global default; VIP-only, as when sending a message.
                user_row = db.query(User).filter(User.id == user.id).first()
                if session.inner_voice_enabled is not None:
                    inner_voice_enabled = session.inner_voice_enabled
                elif user_row is not None and user_row.inner_voice_enabled is not None:
                    inner_voice_enabled = user_row.inner_voice_enabled
                else:
                    inner_voice_enabled = settings.INNER_VOICE_DEFAULT
                if inner_voice_enabled and (user_row is None or not _is_vip(user_row)):
                    inner_voice_enabled = False
                inner_voice_enabled = bool(inner_voice_enabled)

                # Build the context (which now contains the edited message).
                context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)

                llm_result = chat_completion(
                    context_messages,
                    temperature=settings.LLM_TEMPERATURE,
                    max_tokens=settings.LLM_MAX_TOKENS,
                    seed=_pick_llm_seed(),
                )

                new_reply = llm_result.content or ""
                if not inner_voice_enabled:
                    new_reply = _strip_inner_voice_text(new_reply)
                # Same reply length cap as normal sends (0 disables truncation).
                max_chars = settings.CHAT_REPLY_MAX_CHARS or 0
                if max_chars > 0 and len(new_reply) > max_chars:
                    new_reply = new_reply[:max_chars]

                # Preserve the original AI reply the first time it gets replaced.
                if not next_lover_msg.is_edited:
                    next_lover_msg.original_content = next_lover_msg.content

                usage = llm_result.usage or {}
                next_lover_msg.content = new_reply
                next_lover_msg.is_edited = True
                next_lover_msg.edited_at = datetime.utcnow()
                # Keep the model label in step with the token counts recorded below.
                next_lover_msg.model = settings.LLM_MODEL or "qwen-flash"
                next_lover_msg.token_input = usage.get("input_tokens")
                next_lover_msg.token_output = usage.get("output_tokens")
                db.add(next_lover_msg)
                db.flush()

                regenerated = True

                # Possibly rebuild the rolling summary up to the regenerated reply.
                _maybe_generate_summary(db, session, upto_seq=next_lover_msg.seq)

    return success_response(
        EditMessageOut(
            message_id=msg.id,
            new_content=msg.content,
            regenerated=regenerated,
            new_reply=new_reply,
        )
    )
|