Ai_GirlFriend/lover/routers/chat.py

1341 lines
46 KiB
Python
Raw Normal View History

2026-01-31 19:15:41 +08:00
from datetime import datetime, date
from typing import List, Optional
import re
import random
import oss2
from fastapi import APIRouter, Depends, HTTPException, Query, Request
2026-01-31 19:15:41 +08:00
from pydantic import BaseModel, Field
from sqlalchemy import desc
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from ..config import settings
from ..db import get_db
from ..deps import AuthedUser, get_current_user
from ..llm import chat_completion
from ..models import ChatMessage, ChatSession, Lover, User, ChatFact, ChatSummary, VoiceLibrary # type: ignore
from ..response import ApiResponse, success_response
from ..tts import synthesize
from ..vision import describe_image
from .config import _parse_owned_voices
# Router for all chat endpoints; every route below is mounted under "/chat".
router = APIRouter(prefix="/chat", tags=["chat"])
class ChatSendIn(BaseModel):
    """Request body for ``POST /chat/send``."""

    # ID of an existing chat session; required (the description reads "session ID required").
    session_id: int = Field(..., description="会话ID必填")
    # Message text, 1-4000 characters.
    message: str = Field(..., min_length=1, max_length=4000)
class ChatSendOut(BaseModel):
    """Response payload for ``POST /chat/send``."""

    session_id: int
    # IDs of the stored user message and of the stored lover reply.
    user_message_id: int
    lover_message_id: int
    # Reply text as stored (after optional inner-voice stripping and length cap).
    reply: str
    # Whether inner voice was in effect when the reply was generated.
    inner_voice_enabled: bool
    # Token usage reported by the LLM call, if any.
    usage: Optional[dict] = None
    # Daily message allowance and how much of it has been used today.
    chat_limit_daily: int
    chat_used_today: int
class ChatSendImageOut(BaseModel):
    """Response payload for ``POST /chat/send-image`` (same shape as ``ChatSendOut``)."""

    session_id: int
    # IDs of the stored image message and of the stored lover reply.
    user_message_id: int
    lover_message_id: int
    # Reply text as stored.
    reply: str
    # Whether inner voice was in effect when the reply was generated.
    inner_voice_enabled: bool
    # Token usage reported by the LLM call, if any.
    usage: Optional[dict] = None
    # Daily message allowance and how much of it has been used today.
    chat_limit_daily: int
    chat_used_today: int
class ChatConfigOut(BaseModel):
    """Chat configuration returned by the config and inner-voice endpoints."""

    # Session the config refers to; None when no session was requested.
    session_id: Optional[int] = None
    chat_limit_daily: int
    chat_used_today: int
    # Effective inner-voice flag (forced off for non-VIP users by the endpoints).
    inner_voice_enabled: bool
    is_vip: bool
class MessageOut(BaseModel):
    """One chat message as returned to the client."""

    id: int
    # "user" or "lover" (other values are passed through as stored).
    role: str
    # Text, or the image URL for image messages.
    content: str
    seq: int
    created_at: datetime
    # e.g. "text" or "image".
    content_type: Optional[str] = None
    # Free-form extra data (image messages store "image_caption" / "image_error" here).
    extra: Optional[dict] = None
    # Exposed so the frontend can tell whether TTS audio already exists and avoid duplicate requests.
    tts_url: Optional[str] = None
    tts_status: Optional[str] = None
    tts_voice_id: Optional[int] = None
    tts_model_id: Optional[str] = None
class TTSOut(BaseModel):
    """Result of ``POST /chat/messages/tts/{message_id}``."""

    message_id: int
    tts_url: Optional[str] = None
    # "succeeded" on success.
    tts_status: str
    # Voice-library ID used; None when a non-library fallback voice was used.
    tts_voice_id: Optional[int] = None
    tts_model_id: Optional[str] = None
class SessionInitOut(BaseModel):
    """Response for ``GET /chat/session/init``: primary session plus one page of messages."""

    session_id: int
    # One page of messages, oldest first.
    messages: List[MessageOut]
    chat_limit_daily: int
    chat_used_today: int
    inner_voice_enabled: bool
    is_vip: bool
class MessagesPageOut(BaseModel):
    """Response for ``GET /chat/messages``: one page of a session's history."""

    session_id: int
    page: int
    size: int
    # Messages in this page, oldest first.
    messages: List[MessageOut]
    # True when older pages remain.
    has_more: bool
def _reset_chat_quota(user_row: User):
    """Zero the daily chat counter the first time the user is seen on a new day.

    The day boundary is the server's local date (``date.today()``). Only the
    in-memory row is mutated; nothing is flushed here.
    """
    today = date.today()
    if user_row.chat_reset_date == today:
        return
    user_row.chat_reset_date, user_row.chat_used_today = today, 0
def _ensure_quota(user_row: User):
    """Raise HTTP 400 when the user has used up today's chat allowance.

    The counter is reset first if the day has rolled over. The limit is the
    user's ``chat_limit_daily`` or, when unset/zero, ``settings.CHAT_LIMIT_DAILY``.

    Raises:
        HTTPException(400): today's allowance is exhausted.
    """
    _reset_chat_quota(user_row)
    limit = user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY
    # A NULL counter means nothing used yet; comparing None >= int would raise TypeError.
    used = user_row.chat_used_today or 0
    if used >= limit:
        raise HTTPException(status_code=400, detail="今日聊天次数已用完")
def _pick_llm_seed() -> int:
    """Return a fresh random seed in [0, 2**31 - 1] for one LLM call."""
    upper_bound = (1 << 31) - 1
    return random.randint(0, upper_bound)
def _normalize_text(text: Optional[str]) -> str:
    """Collapse each whitespace run to a single space and trim; None/"" give ""."""
    return re.sub(r"\s+", " ", text).strip() if text else ""
# Fixed token surcharge added per context message on top of its content estimate
# (presumably approximating role/formatting overhead; the figure is a heuristic).
_CONTEXT_MSG_OVERHEAD_TOKENS = 4
def _estimate_tokens_for_context(text: Optional[str]) -> int:
    """Cheap token estimate: one token per non-ASCII char, one per started group of 4 ASCII chars.

    Returns 0 for None/"" and at least 1 for any non-empty text.
    """
    if not text:
        return 0
    ascii_count = sum(1 for ch in text if ord(ch) < 128)
    other_count = len(text) - ascii_count
    # (n + 3) // 4 is ceil(n / 4).
    estimate = other_count + (ascii_count + 3) // 4
    return estimate if estimate > 1 else 1
def _estimate_message_tokens(msg: dict) -> int:
    """Estimate one chat message's size: its content estimate plus the fixed per-message overhead.

    Non-dict inputs and missing/None content count as empty content.
    """
    text = ""
    if isinstance(msg, dict):
        text = msg.get("content") or ""
    return _CONTEXT_MSG_OVERHEAD_TOKENS + _estimate_tokens_for_context(text)
def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
    """Return the longest prefix of ``text`` whose estimated size fits in ``max_tokens``.

    The cost function is ``_estimate_tokens_for_context``. Appending a character
    never lowers that estimate, so "prefix of length L fits" is monotone in L.
    A binary search therefore finds the longest fitting prefix exactly, and the
    result is guaranteed to be within budget.

    Returns:
        "" when ``text`` is empty or ``max_tokens <= 0``; ``text`` unchanged when
        it already fits; otherwise a non-empty prefix (any single character
        estimates to 1 token, which always fits a positive budget).
    """
    if not text or max_tokens <= 0:
        return ""
    if _estimate_tokens_for_context(text) <= max_tokens:
        return text
    # Invariant: text[:lo] fits the budget, text[:hi] does not.
    lo, hi = 1, len(text)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if _estimate_tokens_for_context(text[:mid]) <= max_tokens:
            lo = mid
        else:
            hi = mid
    return text[:lo]
def _truncate_message_to_budget(msg: dict, max_tokens: int) -> dict:
    """Return a shallow copy of ``msg`` whose content fits ``max_tokens`` including overhead.

    ``msg`` itself is not modified. A budget at or below the per-message
    overhead yields empty content.
    """
    raw_text = (msg.get("content") if isinstance(msg, dict) else "") or ""
    text_budget = max(0, max_tokens - _CONTEXT_MSG_OVERHEAD_TOKENS)
    shortened = _truncate_text_to_tokens(raw_text, text_budget)
    copy = dict(msg)
    copy["content"] = shortened
    return copy
def _trim_context_messages(messages: List[dict]) -> List[dict]:
    """Shrink an LLM context so its estimated size fits ``settings.CHAT_CONTEXT_MAX_TOKENS``.

    ``messages`` is split into the leading run of system messages (``prefix``)
    and everything after it (``convo``). Reduction steps, each applied only
    while still over budget:

    1. drop trailing prefix messages, always keeping the first (base) system prompt;
    2. drop the oldest ``convo`` messages, always keeping the latest one;
    3. truncate the content of the earliest remaining ``convo`` message;
    4. as a last resort, truncate the base system prompt.

    Returns ``messages`` itself when the limit is unset/non-positive or already
    satisfied; otherwise a new list (truncated entries are copies, the rest are
    the original dicts). Sizes are heuristic estimates from
    ``_estimate_message_tokens``. The result can still exceed the budget when
    per-message overhead alone is larger than the limit.
    """
    max_tokens = settings.CHAT_CONTEXT_MAX_TOKENS or 0
    if max_tokens <= 0:
        # No limit configured.
        return messages
    # Split system prefix and dialogue messages.
    first_non_system = next(
        (idx for idx, msg in enumerate(messages) if msg.get("role") != "system"),
        len(messages),
    )
    prefix = list(messages[:first_non_system])
    convo = list(messages[first_non_system:])

    # Reads prefix/convo through the closure, so it sees later rebindings of convo.
    def total_tokens() -> int:
        return sum(_estimate_message_tokens(m) for m in prefix) + sum(_estimate_message_tokens(m) for m in convo)

    total = total_tokens()
    if total <= max_tokens:
        return messages
    # Drop optional system messages from the end (keep the base system prompt).
    while len(prefix) > 1 and total > max_tokens:
        removed = prefix.pop()
        total -= _estimate_message_tokens(removed)
    # Drop oldest dialogue messages, keep at least the latest one.
    if convo:
        start = 0
        while start < len(convo) - 1 and total > max_tokens:
            total -= _estimate_message_tokens(convo[start])
            start += 1
        if start:
            convo = convo[start:]
    # If still over budget, truncate the earliest remaining dialogue message.
    if convo and total > max_tokens:
        prefix_tokens = sum(_estimate_message_tokens(m) for m in prefix)
        tail_tokens = sum(_estimate_message_tokens(m) for m in convo[1:])
        # May be <= 0, in which case the message content becomes "".
        allowed_for_first = max_tokens - prefix_tokens - tail_tokens
        convo[0] = _truncate_message_to_budget(convo[0], allowed_for_first)
        total = total_tokens()
    if prefix and total > max_tokens:
        # Last resort: truncate base system prompt.
        allowed_for_system = max_tokens - sum(_estimate_message_tokens(m) for m in convo)
        prefix[0] = _truncate_message_to_budget(prefix[0], allowed_for_system)
    return prefix + convo
def _is_exact_repeat(last_msg: Optional[ChatMessage], reply_text: str) -> bool:
    """Tell whether ``reply_text`` repeats the previous lover message, ignoring whitespace differences.

    Always False when there is no previous message or it was not sent by the lover.
    """
    if last_msg and last_msg.role == "lover":
        previous = _normalize_text(last_msg.content or "")
        return previous == _normalize_text(reply_text)
    return False
def _fetch_session(db: Session, user_id: int, session_id: Optional[int]) -> Optional[ChatSession]:
    """Load the chat session ``session_id`` if it belongs to ``user_id``.

    Returns None when ``session_id`` is falsy (None or 0).

    Raises:
        HTTPException(404): an ID was given but no matching session belongs to the user.
    """
    if session_id:
        found = (
            db.query(ChatSession)
            .filter(ChatSession.id == session_id, ChatSession.user_id == user_id)
            .first()
        )
        if not found:
            raise HTTPException(status_code=404, detail="会话不存在")
        return found
    return None
def _create_session(db: Session, user: AuthedUser, lover: Lover, inner_voice_enabled: bool) -> ChatSession:
    """Add a new "active" chat session for ``user`` and ``lover`` and flush it.

    The flush lets callers use ``session.id`` right away; no commit happens here.
    """
    timestamp = datetime.utcnow()
    columns = dict(
        user_id=user.id,
        lover_id=lover.id,
        model=settings.LLM_MODEL or "qwen-flash",
        status="active",
        inner_voice_enabled=inner_voice_enabled,
        created_at=timestamp,
        updated_at=timestamp,
        last_message_at=timestamp,
    )
    new_session = ChatSession(**columns)
    db.add(new_session)
    db.flush()
    return new_session
def _build_system_prompt(lover: Lover, user: AuthedUser, inner_voice_enabled: bool) -> str:
    """Build the base system prompt for the lover persona, one instruction per line.

    Lines, in order:
    - the role line with the user's nickname (or "用户");
    - the safety rules;
    - the lover's ``personality_prompt`` when set;
    - the anti-repetition rule;
    - the inner-voice instruction.

    The inner-voice instruction talks about full-width brackets (全角括号), so its
    example uses full-width （…） as well. A half-width "(…)" example contradicts
    the wording and invites the model to use ASCII parentheses.
    """
    parts = [
        f"你是用户 {user.nickname or '用户'} 的虚拟恋人,请用亲密、温暖的语气聊天。",
        "禁止涉政、违法、暴力、未成年相关内容。",
    ]
    if lover.personality_prompt:
        parts.append(f"人格设定:{lover.personality_prompt}")
    parts.append("避免与上一条回复完全一致;若用户要求复述,请换一种说法简要转述。")
    if inner_voice_enabled:
        parts.append("在适当时机用全角括号（…）加入心理活动,正文与心声并存。")
    else:
        parts.append("不要输出括号（…）内的心理活动。")
    return "\n".join(parts)
def _build_context_messages(
    db: Session,
    session: ChatSession,
    lover: Lover,
    user: AuthedUser,
    inner_voice_enabled: bool,
) -> List[dict]:
    """Assemble the message list sent to the LLM for ``session``.

    Order:
    1. the base system prompt;
    2. the latest summary row for the session (if any);
    3. up to ``settings.CHAT_FACT_MAX_ROWS`` facts about this user/lover pair,
       ordered by weight then recency (if any);
    4. the newest ``settings.CHAT_RECENT_WINDOW`` messages, oldest first.

    Image messages are replaced by a sentence containing their stored
    ``image_caption``. Edited messages carry a marker with their original text.
    The list is finally passed through ``_trim_context_messages``.
    """
    messages: List[dict] = [
        {"role": "system", "content": _build_system_prompt(lover, user, inner_voice_enabled)}
    ]

    # Latest summary.
    summary = (
        db.query(ChatSummary)
        .filter(ChatSummary.session_id == session.id)
        .order_by(ChatSummary.upto_seq.desc())
        .first()
    )
    if summary:
        # The closing ")" and ":" keep the seq number from running into the summary text.
        messages.append(
            {"role": "system", "content": f"对话摘要(累积至 seq {summary.upto_seq}):{summary.summary_text}"}
        )

    # Profile facts, by weight then newest first.
    facts = (
        db.query(ChatFact)
        .filter(ChatFact.user_id == user.id, ChatFact.lover_id == lover.id)
        .order_by(desc(ChatFact.weight), ChatFact.created_at.desc())
        .limit(settings.CHAT_FACT_MAX_ROWS)
        .all()
    )
    if facts:
        fact_lines = [f"- {f.kind or 'fact'}: {f.content}" for f in facts]
        messages.append({"role": "system", "content": "画像与重要事实:\n" + "\n".join(fact_lines)})

    # Recent message window (includes the newest user message), fetched newest first.
    history = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc())
        .limit(settings.CHAT_RECENT_WINDOW)
        .all()
    )
    user_nickname = user.nickname or "用户"
    for msg in reversed(history):
        if msg.role == "lover":
            role = "assistant"
        elif msg.role == "user":
            role = "user"
        else:
            role = "system"

        if msg.content_type == "image":
            # Images are described to the model through their stored caption.
            extra = msg.extra if isinstance(msg.extra, dict) else {}
            caption = extra.get("image_caption")
            if msg.role == "user":
                content = f"{user_nickname}给你发了一张图片,内容是:{caption or '图片已收到,但未能识别'}"
            else:
                content = f"恋人发送了一张图片,内容是:{caption or '图片'}"
        else:
            content = msg.content or ""

        # Flag edited messages so the model knows what changed.
        if msg.is_edited and msg.original_content:
            if msg.role == "lover":
                # The user corrected an AI reply; the corrected text must take precedence.
                content = f"[用户纠正] {content}\n(原回复:{msg.original_content})\n注意:用户认为原回复不准确,已纠正为上述内容,请以用户纠正的内容为准。"
            else:
                content = f"{content}\n[已编辑,原内容:{msg.original_content}]"

        messages.append({"role": role, "content": content})
    return _trim_context_messages(messages)
def _is_vip(user_row: User) -> bool:
    """Return True while ``user_row.vip_endtime`` (Unix seconds) is in the future.

    Missing/zero end times, and values ``int()`` cannot convert, count as not VIP.
    The current time comes from ``time.time()``. The previous
    ``datetime.utcnow().timestamp()`` interpreted a naive UTC time as local
    time and was off by the server's UTC offset.
    """
    import time  # function-scope import, as this module does elsewhere

    try:
        return bool(user_row.vip_endtime and int(user_row.vip_endtime) > int(time.time()))
    except Exception:
        # Deliberate best effort: malformed data must not break request handling.
        return False
def _get_or_create_primary_session(db: Session, user: AuthedUser, lover: Lover) -> ChatSession:
    """
    Return the user's single active session with ``lover``, creating it if none exists.

    If several sessions are active, the newest (by ``created_at``) is kept and
    the others are marked "archived" so only one remains active. A newly created
    session gets the lover's opening line (or a default greeting) as message
    seq 1. Its inner-voice flag comes from the user's default, else
    ``settings.INNER_VOICE_DEFAULT``.

    On IntegrityError during creation the whole DB transaction is rolled back
    (``db.rollback()``, not a savepoint). The newest session for the pair is then
    re-activated and returned, or the error is re-raised if none exists.
    """
    # Lock the active sessions so concurrent callers serialize here.
    active_sessions = (
        db.query(ChatSession)
        .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id, ChatSession.status == "active")
        .with_for_update()
        .order_by(ChatSession.created_at.desc())
        .all()
    )
    if active_sessions:
        primary = active_sessions[0]
        others = active_sessions[1:]
        for s in others:
            s.status = "archived"
            db.add(s)
        db.flush()
        return primary
    # Create a new session; inner voice defaults to the user's setting, else the global default.
    user_row = db.query(User).filter(User.id == user.id).first()
    inner_voice = (
        user_row.inner_voice_enabled
        if user_row and user_row.inner_voice_enabled is not None
        else settings.INNER_VOICE_DEFAULT
    )
    try:
        session = _create_session(db, user, lover, inner_voice_enabled=inner_voice)
    except IntegrityError:
        # Concurrent creation, or a unique key hit by an archived session: roll back
        # and return the existing session instead.
        db.rollback()
        existing = (
            db.query(ChatSession)
            .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id)
            .with_for_update()
            .order_by(ChatSession.created_at.desc())
            .first()
        )
        if existing:
            if existing.status != "active":
                existing.status = "active"
                existing.updated_at = datetime.utcnow()
                db.add(existing)
                db.flush()
            return existing
        raise
    # Store the opening line as the first (lover) message.
    opening = lover.opening_line or "嗨,我是你的恋人,很高兴见到你~"
    now = datetime.utcnow()
    first_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=opening,
        seq=1,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
    )
    db.add(first_msg)
    session.last_message_at = now
    db.add(session)
    db.flush()
    return session
def _strip_inner_voice_text(raw: Optional[str]) -> str:
    """
    Remove full-width-bracketed inner-voice asides, e.g. "（偷偷开心）", and normalize whitespace.

    Used to build TTS input and to clean replies when inner voice is disabled.
    Half-width "(...)" text is left untouched. An aside left unclosed at the end
    (e.g. output cut off mid-aside) is removed up to the end of the text.
    Whitespace runs collapse to single spaces and the result is trimmed;
    None/"" give "".
    """
    if not raw:
        return ""
    # Drop every complete （…） aside; [^）]* also spans newlines inside an aside.
    cleaned = re.sub(r"（[^）]*）", "", raw)
    # An opening bracket still present has no closing bracket after it: drop to the end.
    cleaned = re.sub(r"（[^）]*$", "", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned)
    return cleaned.strip()
def _pick_available_voice(db: Session, lover: Lover, user_row: User) -> VoiceLibrary:
    """Choose the TTS voice for ``lover``, restricted to the lover's gender.

    Preference order:
    1. the lover's configured voice, if it is free or owned by the user;
    2. the gender's default voice (``is_default``), not checked for price;
    3. the lowest-ID voice of that gender that is free or owned by the user.

    Step 3 applies the same ownership rule as step 1. Otherwise a user could
    fall through to a paid voice they never bought.

    Raises:
        HTTPException(404): no usable voice exists.
        HTTPException(500): the chosen voice has no ``voice_code``.
    """
    owned_ids = _parse_owned_voices(user_row.owned_voice_ids)

    def usable(voice: VoiceLibrary) -> bool:
        # Paid voices (price_gold > 0) must be owned; free ones are always usable.
        return (voice.price_gold or 0) <= 0 or int(voice.id) in owned_ids

    candidate: Optional[VoiceLibrary] = None
    if lover.voice_id:
        configured = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.id == lover.voice_id, VoiceLibrary.gender == lover.gender)
            .first()
        )
        if configured and usable(configured):
            candidate = configured
    if not candidate:
        candidate = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.gender == lover.gender, VoiceLibrary.is_default.is_(True))
            .first()
        )
    if not candidate:
        fallback_query = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.gender == lover.gender)
            .order_by(VoiceLibrary.id.asc())
        )
        candidate = next((v for v in fallback_query if usable(v)), None)
    if not candidate:
        raise HTTPException(status_code=404, detail="未找到可用音色")
    if not candidate.voice_code:
        raise HTTPException(status_code=500, detail="音色未配置 voice_code")
    return candidate
def _upload_tts_to_oss(file_bytes: bytes, lover_id: int, message_id: int, request: Request = None) -> str:
    """
    Store synthesized TTS audio and return a public URL for it.

    When all four Aliyun OSS settings are present, the audio is uploaded as
    ``lover/{lover_id}/tts/{message_id}.mp3``. The URL uses the CDN domain if one
    is configured, otherwise the bucket's virtual-host URL.

    Otherwise the file is written to ``public/tts/{lover_id}/{message_id}.mp3``
    (relative to the working directory). The URL base is then taken from the
    request's scheme and Host header, or from ``BACKEND_URL`` when no request is
    given.

    Raises:
        HTTPException(502): the OSS upload failed.
        HTTPException(500): the local file could not be written.
    """
    has_oss = (
        settings.ALIYUN_OSS_ACCESS_KEY_ID
        and settings.ALIYUN_OSS_ACCESS_KEY_SECRET
        and settings.ALIYUN_OSS_BUCKET_NAME
        and settings.ALIYUN_OSS_ENDPOINT
    )
    if has_oss:
        object_name = f"lover/{lover_id}/tts/{message_id}.mp3"
        endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/")
        try:
            auth = oss2.Auth(settings.ALIYUN_OSS_ACCESS_KEY_ID, settings.ALIYUN_OSS_ACCESS_KEY_SECRET)
            bucket = oss2.Bucket(auth, endpoint, settings.ALIYUN_OSS_BUCKET_NAME)
            bucket.put_object(object_name, file_bytes)
        except Exception as exc:
            raise HTTPException(status_code=502, detail=f"上传语音失败: {exc}") from exc
        cdn_domain = settings.ALIYUN_OSS_CDN_DOMAIN
        if cdn_domain:
            return f"{cdn_domain.rstrip('/')}/{object_name}"
        # The endpoint may include a scheme ("https://oss-cn-xxx.aliyuncs.com"). Strip it
        # so the URL is not "https://bucket.https://...".
        endpoint_host = endpoint.split("://", 1)[-1]
        return f"https://{settings.ALIYUN_OSS_BUCKET_NAME}.{endpoint_host}/{object_name}"

    # No OSS configured: save under public/tts and serve from this backend.
    import os
    from pathlib import Path

    lover_dir = Path("public/tts") / str(lover_id)
    file_path = lover_dir / f"{message_id}.mp3"
    try:
        lover_dir.mkdir(parents=True, exist_ok=True)
        file_path.write_bytes(file_bytes)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"保存语音文件失败: {exc}") from exc

    if request:
        # NOTE(review): the Host header is client-supplied; consider a configured
        # public base URL if these links are ever shown to other users.
        host = request.headers.get("host", "127.0.0.1:8000")
        scheme = "https" if request.url.scheme == "https" else "http"
        backend_url = f"{scheme}://{host}"
    else:
        backend_url = os.getenv("BACKEND_URL", "http://127.0.0.1:8000")
    return f"{backend_url.rstrip('/')}/tts/{lover_id}/{message_id}.mp3"
@router.get("/session/init", response_model=ApiResponse[SessionInitOut])
def init_session(
    page: int = Query(1, ge=1),
    size: int = Query(15, ge=1, le=100),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the user's primary session (created if needed) with one page of messages.

    Page 1 holds the newest ``size`` messages, and each page is returned oldest
    first. Also returns:
    - the daily quota (reset first if the day rolled over);
    - the effective inner-voice flag, resolved session -> user -> global default;
    - the VIP state.

    If the flag resolves to on for a non-VIP user, it is switched off on the session.

    Raises:
        HTTPException(404): no lover has been created yet, or the user row is missing.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    session = _get_or_create_primary_session(db, user, lover)
    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    _reset_chat_quota(user_row)

    # Page newest-first (id breaks ties), then present the page oldest-first.
    records = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc(), ChatMessage.id.desc())
        .offset((page - 1) * size)
        .limit(size)
        .all()
    )
    records.sort(key=lambda m: m.seq or 0)

    def to_message(msg: ChatMessage) -> MessageOut:
        return MessageOut(
            id=msg.id,
            role=msg.role,
            content=msg.content,
            seq=msg.seq or 0,
            created_at=msg.created_at,
            content_type=getattr(msg, "content_type", None),
            extra=getattr(msg, "extra", None),
            tts_url=getattr(msg, "tts_url", None),
            tts_status=getattr(msg, "tts_status", None),
            tts_voice_id=getattr(msg, "tts_voice_id", None),
            tts_model_id=getattr(msg, "tts_model_id", None),
        )

    inner_voice = (
        session.inner_voice_enabled
        if session.inner_voice_enabled is not None
        else user_row.inner_voice_enabled
        if user_row.inner_voice_enabled is not None
        else settings.INNER_VOICE_DEFAULT
    )
    # A user who is no longer VIP must not keep inner voice on: turn it off on the session too.
    if inner_voice and not _is_vip(user_row):
        inner_voice = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()
    return success_response(
        SessionInitOut(
            session_id=session.id,
            messages=[to_message(m) for m in records],
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
            inner_voice_enabled=bool(inner_voice),
            is_vip=_is_vip(user_row),
        )
    )
@router.get("/messages", response_model=ApiResponse[MessagesPageOut])
def list_messages(
    session_id: int,
    page: int = Query(1, ge=1),
    size: int = Query(15, ge=1, le=100),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return one page of a session's history (page 1 = newest), each page oldest-first.

    Pages are cut by ``seq`` (newest first, ``id`` as tiebreak), the same order
    ``/session/init`` uses. Page N of either endpoint therefore covers the same
    messages. ``has_more`` is True while older pages remain.

    Raises:
        HTTPException(404): the session does not exist or is not the user's.
    """
    session = _fetch_session(db, user.id, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")
    query = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.seq.desc(), ChatMessage.id.desc())
    )
    total = query.count()
    records = query.offset((page - 1) * size).limit(size).all()
    # Present each page in ascending seq order.
    records.sort(key=lambda m: m.seq or 0)

    def to_message(msg: ChatMessage) -> MessageOut:
        return MessageOut(
            id=msg.id,
            role=msg.role,
            content=msg.content,
            seq=msg.seq or 0,
            created_at=msg.created_at,
            content_type=getattr(msg, "content_type", None),
            extra=getattr(msg, "extra", None),
            tts_url=getattr(msg, "tts_url", None),
            tts_status=getattr(msg, "tts_status", None),
            tts_voice_id=getattr(msg, "tts_voice_id", None),
            tts_model_id=getattr(msg, "tts_model_id", None),
        )

    return success_response(
        MessagesPageOut(
            session_id=session.id,
            page=page,
            size=size,
            messages=[to_message(m) for m in records],
            has_more=page * size < total,
        )
    )
@router.post("/messages/tts/{message_id}", response_model=ApiResponse[TTSOut])
def generate_message_tts(
    message_id: int,
    request: Request,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Synthesize speech for one of the lover's text messages, or return the cached result.

    Inner-voice asides are stripped before synthesis, and texts longer than 80
    characters are rejected. A previously succeeded result is returned
    unchanged. If synthesis or upload fails with an HTTPException whose detail
    mentions "ModelNotFound" or "not found", one retry is made with the
    voice-call fallback voice/model. Fallback audio records
    ``tts_voice_id = None`` because that voice is not a VoiceLibrary row.

    Raises:
        HTTPException(404): message, lover or user not found.
        HTTPException(400): not a lover text message, empty text, or text over 80 chars.
        HTTPException(502): synthesis failed with a non-HTTP error.
        The original HTTPException is re-raised when the fallback does not apply or fails.
    """

    def _to_out(m: ChatMessage) -> TTSOut:
        return TTSOut(
            message_id=m.id,
            tts_url=m.tts_url,
            tts_status=m.tts_status or "succeeded",
            tts_voice_id=getattr(m, "tts_voice_id", None),
            tts_model_id=getattr(m, "tts_model_id", None),
        )

    msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.id == message_id, ChatMessage.user_id == user.id)
        .with_for_update()
        .first()
    )
    if not msg:
        raise HTTPException(status_code=404, detail="消息不存在")
    if msg.role != "lover" or msg.content_type != "text":
        raise HTTPException(status_code=400, detail="仅支持恋人文本消息生成语音")
    lover = db.query(Lover).filter(Lover.id == msg.lover_id, Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人不存在")
    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    clean_text = _strip_inner_voice_text(msg.content or "")
    if not clean_text:
        raise HTTPException(status_code=400, detail="消息内容为空,无法生成语音")
    if len(clean_text) > 80:
        raise HTTPException(status_code=400, detail="文本长度超过80字符不支持生成语音")
    if msg.tts_url and msg.tts_status == "succeeded":
        return success_response(_to_out(msg))

    voice = _pick_available_voice(db, lover, user_row)
    model = voice.tts_model_id or "cosyvoice-v2"
    # Mark as pending before the slow synthesis call (the row is locked above).
    msg.tts_status = "pending"
    msg.tts_error = None
    db.add(msg)
    db.flush()

    # Voice/model actually used; the fallback path overrides them.
    used_voice_id: Optional[int] = voice.id
    used_model = model
    try:
        audio_bytes, fmt_name = synthesize(
            clean_text,
            model=model,
            voice=voice.voice_code,
        )
        url = _upload_tts_to_oss(audio_bytes, lover.id, msg.id, request)
    except HTTPException as exc:
        detail_text = str(exc.detail)
        url = ""
        if "ModelNotFound" in detail_text or "not found" in detail_text:
            fallback_voice = settings.VOICE_CALL_TTS_VOICE or "longxiaochun_v2"
            fallback_model = settings.VOICE_CALL_TTS_MODEL or "cosyvoice-v2"
            if fallback_voice != voice.voice_code or fallback_model != model:
                try:
                    audio_bytes, fmt_name = synthesize(
                        clean_text,
                        model=fallback_model,
                        voice=fallback_voice,
                    )
                    url = _upload_tts_to_oss(audio_bytes, lover.id, msg.id, request)
                except Exception:
                    # The original error is the one reported below.
                    url = ""
                else:
                    # The fallback voice is not a voice-library row, so no library ID is recorded.
                    used_voice_id = None
                    used_model = fallback_model
        if not url:
            msg.tts_status = "failed"
            msg.tts_error = detail_text[:255]
            db.add(msg)
            db.flush()
            raise
    except Exception as exc:
        msg.tts_status = "failed"
        msg.tts_error = str(exc)[:255]
        db.add(msg)
        db.flush()
        raise HTTPException(status_code=502, detail="语音生成失败,请稍后重试") from exc

    msg.tts_url = url
    msg.tts_status = "succeeded"
    # Record what was actually used, not stale values from earlier attempts.
    msg.tts_voice_id = used_voice_id
    msg.tts_model_id = used_model
    msg.tts_format = fmt_name
    msg.tts_error = None
    db.add(msg)
    db.flush()
    return success_response(_to_out(msg))
class ChatSendImageIn(BaseModel):
    """Request body for ``POST /chat/send-image``."""

    # Existing session to post into.
    session_id: int
    # Public URL of the image (the endpoint checks the scheme).
    image_url: str
@router.post("/send-image", response_model=ApiResponse[ChatSendImageOut])
def send_image_message(
    payload: ChatSendImageIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Post an image (VIP only), caption it with the vision model, and return the lover's reply.

    Flow:
    - check the lover, the row-locked user, VIP status, the daily quota (images
      count too), the session, and that no user message is still unanswered;
    - store the image message and caption it;
    - call the LLM, retrying once if the reply exactly repeats the previous lover message;
    - store the reply, bump ``chat_used_today``, and possibly write a rolling summary.

    A vision failure does not block the chat. No caption is stored, and the
    error goes to ``extra["image_error"]``.

    Raises:
        HTTPException(400): bad image URL, quota used up, or a reply is still pending.
        HTTPException(403): user is not VIP.
        HTTPException(404): lover, user or session not found.
    """
    img_url = payload.image_url or ""
    # Require a real http(s) scheme; a bare "http" prefix also accepted "httpfoo..." or "http:x".
    if not img_url.startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="图片URL格式不正确")
    # The lover must exist.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    # VIP check.
    if not _is_vip(user_row):
        raise HTTPException(status_code=403, detail="仅 VIP 用户可发送图片")
    _ensure_quota(user_row)  # images count against the quota too
    # An existing session is required.
    session = _fetch_session(db, user.id, payload.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")
    # No back-to-back sends: if the newest message is the user's, wait for the reply.
    last_msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .with_for_update()
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    if last_msg and last_msg.role == "user":
        raise HTTPException(status_code=400, detail="请等待恋人回复后再发送")
    # Next sequence number.
    next_seq = (last_msg.seq if last_msg and last_msg.seq else 0) + 1
    now = datetime.utcnow()
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="image",
        content=img_url,
        seq=next_seq,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
        extra={},
    )
    db.add(user_msg)
    db.flush()
    # Image understanding.
    try:
        caption = describe_image(img_url)
    except HTTPException as exc:
        # Keep the failure but do not block the chat. Store no caption, so the
        # context builder's "received but not recognised" fallback is used
        # instead of a placeholder posing as the image content.
        user_msg.extra = {"image_caption": None, "image_error": str(exc.detail)}
    else:
        user_msg.extra = {"image_caption": caption}
    db.add(user_msg)
    db.flush()
    inner_voice_enabled = session.inner_voice_enabled if session.inner_voice_enabled is not None else settings.INNER_VOICE_DEFAULT
    # Non-VIP users must not have inner voice; clear any stale session flag.
    if inner_voice_enabled and not _is_vip(user_row):
        inner_voice_enabled = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()
    context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)
    llm_result = chat_completion(
        context_messages,
        temperature=settings.LLM_TEMPERATURE,
        max_tokens=settings.LLM_MAX_TOKENS,
        seed=_pick_llm_seed(),
    )
    reply_text = llm_result.content or ""
    # With inner voice off, strip bracketed asides from the reply.
    if not inner_voice_enabled:
        reply_text = _strip_inner_voice_text(reply_text)
    # Length cap (0 disables it).
    max_chars = settings.CHAT_REPLY_MAX_CHARS or 0
    if max_chars > 0 and len(reply_text) > max_chars:
        reply_text = reply_text[:max_chars]
    if _is_exact_repeat(last_msg, reply_text):
        # Retry once with an explicit instruction not to repeat.
        retry_messages = list(context_messages)
        retry_messages.append(
            {
                "role": "system",
                "content": "请避免与上一条恋人回复完全相同,换一种说法简要复述,保持亲密语气。",
            }
        )
        retry_result = chat_completion(
            retry_messages,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
            seed=_pick_llm_seed(),
        )
        retry_text = retry_result.content or ""
        if not inner_voice_enabled:
            retry_text = _strip_inner_voice_text(retry_text)
        if max_chars > 0 and len(retry_text) > max_chars:
            retry_text = retry_text[:max_chars]
        if retry_text and not _is_exact_repeat(last_msg, retry_text):
            llm_result = retry_result
            reply_text = retry_text
        else:
            base_text = last_msg.content or reply_text
            reply_text = f"我刚才说的是:{base_text}"
            if max_chars > 0 and len(reply_text) > max_chars:
                reply_text = reply_text[:max_chars]
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=reply_text,
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=(llm_result.usage or {}).get("input_tokens"),
        token_output=(llm_result.usage or {}).get("output_tokens"),
    )
    db.add(lover_msg)
    # Update the session and the quota.
    session.last_message_at = datetime.utcnow()
    user_row.chat_used_today = (user_row.chat_used_today or 0) + 1
    db.add(session)
    db.add(user_row)
    db.flush()
    _maybe_generate_summary(db, session, upto_seq=lover_msg.seq)
    return success_response(
        ChatSendImageOut(
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            reply=reply_text,
            inner_voice_enabled=bool(inner_voice_enabled),
            usage=llm_result.usage,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
        )
    )
def _maybe_generate_summary(db: Session, session: ChatSession, upto_seq: int):
    """
    Write a new rolling summary row to nf_chat_summary once enough new messages pile up.

    "New" means ``last_upto < seq <= upto_seq``, where ``last_upto`` is the
    ``upto_seq`` of the latest existing summary (0 if none). A summary is
    written when either trigger is reached:
    - ``CHAT_SUMMARY_TRIGGER_MSGS`` new messages, or
    - ``CHAT_SUMMARY_TRIGGER_TOKENS`` estimated tokens.

    Non-positive triggers are disabled, and with both disabled nothing happens.

    ``_build_context_messages`` loads only the latest summary row and labels it
    cumulative, so the previous summary is sent along with the new messages.
    The new row therefore covers the whole conversation, not just the latest
    batch.
    """
    threshold = settings.CHAT_SUMMARY_TRIGGER_MSGS or 0
    token_threshold = settings.CHAT_SUMMARY_TRIGGER_TOKENS or 0
    if threshold <= 0 and token_threshold <= 0:
        return
    last_summary = (
        db.query(ChatSummary)
        .filter(ChatSummary.session_id == session.id)
        .order_by(ChatSummary.upto_seq.desc())
        .first()
    )
    last_upto = last_summary.upto_seq if last_summary else 0
    new_msgs = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id, ChatMessage.seq > last_upto, ChatMessage.seq <= upto_seq)
        .order_by(ChatMessage.seq.asc())
        .all()
    )

    def estimate_tokens(text: str) -> int:
        # Deliberately coarse: about 2 characters per token for mixed Chinese/English text.
        if not text:
            return 0
        return max(1, len(text) // 2)

    total_tokens = sum(estimate_tokens(m.content or "") for m in new_msgs)
    count_reached = threshold > 0 and len(new_msgs) >= threshold
    tokens_reached = token_threshold > 0 and total_tokens >= token_threshold
    if not (count_reached or tokens_reached):
        return

    # Build the summarizer input; images are represented by their caption, not their URL.
    lines = []
    for m in new_msgs:
        speaker = "用户" if m.role == "user" else "恋人" if m.role == "lover" else "系统"
        if m.content_type == "image":
            extra = m.extra if isinstance(m.extra, dict) else {}
            caption = extra.get("image_caption")
            text = f"[图片] {caption}" if caption else "[图片]"
        else:
            text = m.content
        lines.append(f"{speaker}: {text}")
    dialogue = "\n".join(lines)
    if last_summary and last_summary.summary_text:
        user_content = f"已有摘要:{last_summary.summary_text}\n\n新增对话:\n{dialogue}"
    else:
        user_content = dialogue
    prompt = [
        {
            "role": "system",
            "content": "你是对话摘要助手请用简洁中文总结关键事件、情绪、约定与待办50~120字。",
        },
        {"role": "user", "content": user_content},
    ]
    result = chat_completion(prompt, temperature=0.3, max_tokens=300)
    summary = ChatSummary(
        session_id=session.id,
        upto_seq=upto_seq,
        summary_text=result.content or "",
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=(result.usage or {}).get("input_tokens"),
        token_output=(result.usage or {}).get("output_tokens"),
        created_at=datetime.utcnow(),
    )
    db.add(summary)
    db.flush()
@router.post("/send", response_model=ApiResponse[ChatSendOut])
def send_message(
    payload: ChatSendIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Store a user text message, generate the lover's reply, and return both.

    Flow:
    - check that the lover and the row-locked user exist, and enforce the daily quota;
    - load the session and refuse the send while the newest message is an
      unanswered user message;
    - store the user message and call the LLM, retrying once if the reply
      exactly repeats the previous lover message;
    - store the reply, bump ``chat_used_today``, and possibly write a rolling summary.

    Raises:
        HTTPException(404): lover, user or session not found.
        HTTPException(400): quota used up, or a reply is still pending.
    """
    # The lover must exist.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    # Quota check; the user row is locked because chat_used_today is incremented below.
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    _ensure_quota(user_row)
    # An existing session is required.
    session = _fetch_session(db, user.id, payload.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")
    # No back-to-back sends: if the newest message is the user's, wait for the reply.
    last_msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .with_for_update()
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    if last_msg and last_msg.role == "user":
        raise HTTPException(status_code=400, detail="请等待恋人回复后再发送")
    # Next sequence number.
    next_seq = (last_msg.seq if last_msg and last_msg.seq else 0) + 1
    now = datetime.utcnow()
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="text",
        content=payload.message,
        seq=next_seq,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
    )
    db.add(user_msg)
    # Flushed before building the context so the new message is part of the history query.
    db.flush()
    inner_voice_enabled = session.inner_voice_enabled if session.inner_voice_enabled is not None else settings.INNER_VOICE_DEFAULT
    # Force inner voice off for non-VIP users so a stale session flag does not linger.
    if inner_voice_enabled and not _is_vip(user_row):
        inner_voice_enabled = False
        session.inner_voice_enabled = False
        db.add(session)
        db.flush()
    context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)
    llm_result = chat_completion(
        context_messages,
        temperature=settings.LLM_TEMPERATURE,
        max_tokens=settings.LLM_MAX_TOKENS,
        seed=_pick_llm_seed(),
    )
    reply_text = llm_result.content or ""
    if not inner_voice_enabled:
        reply_text = _strip_inner_voice_text(reply_text)
    # Reply length cap (0 means no truncation).
    max_chars = settings.CHAT_REPLY_MAX_CHARS or 0
    if max_chars > 0 and len(reply_text) > max_chars:
        reply_text = reply_text[:max_chars]
    if _is_exact_repeat(last_msg, reply_text):
        # Retry once with an explicit instruction not to repeat the previous reply.
        retry_messages = list(context_messages)
        retry_messages.append(
            {
                "role": "system",
                "content": "请避免与上一条恋人回复完全相同,换一种说法简要复述,保持亲密语气。",
            }
        )
        retry_result = chat_completion(
            retry_messages,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
            seed=_pick_llm_seed(),
        )
        retry_text = retry_result.content or ""
        if not inner_voice_enabled:
            retry_text = _strip_inner_voice_text(retry_text)
        if max_chars > 0 and len(retry_text) > max_chars:
            retry_text = retry_text[:max_chars]
        if retry_text and not _is_exact_repeat(last_msg, retry_text):
            llm_result = retry_result
            reply_text = retry_text
        else:
            # Still a repeat: explicitly frame it as restating the previous reply.
            base_text = last_msg.content or reply_text
            reply_text = f"我刚才说的是:{base_text}"
            if max_chars > 0 and len(reply_text) > max_chars:
                reply_text = reply_text[:max_chars]
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content=reply_text,
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        token_input=(llm_result.usage or {}).get("input_tokens"),
        token_output=(llm_result.usage or {}).get("output_tokens"),
    )
    db.add(lover_msg)
    # Update the session and the quota.
    session.last_message_at = datetime.utcnow()
    user_row.chat_used_today = (user_row.chat_used_today or 0) + 1
    db.add(session)
    db.add(user_row)
    db.flush()
    # Trigger the rolling summary if thresholds are reached.
    _maybe_generate_summary(db, session, upto_seq=lover_msg.seq)
    return success_response(
        ChatSendOut(
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            reply=reply_text,
            inner_voice_enabled=bool(inner_voice_enabled),
            usage=llm_result.usage,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
        )
    )
class InnerVoiceIn(BaseModel):
    """Request body for ``PUT /chat/inner-voice``."""

    session_id: int
    # Desired inner-voice state; turning it on requires VIP.
    enabled: bool
@router.put("/inner-voice", response_model=ApiResponse[ChatConfigOut])
def update_inner_voice(
    payload: InnerVoiceIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Turn inner voice on or off for one of the user's sessions and return the chat config.

    Turning it on requires an active VIP membership. Turning it off is always
    allowed. The session must already exist; it is not created here.

    Raises:
        HTTPException(404): session, lover or user row not found.
        HTTPException(403): enabling without VIP.
    """
    session = _fetch_session(db, user.id, payload.session_id)
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到,请先完成创建流程")
    if not session:
        raise HTTPException(status_code=404, detail="会话不存在")
    # The user row is needed for the VIP check and the quota fields in the response.
    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Only VIP users may turn inner voice on.
    if payload.enabled and not _is_vip(user_row):
        raise HTTPException(status_code=403, detail="仅 VIP 用户可开启心声")
    session.inner_voice_enabled = payload.enabled
    db.add(session)
    db.flush()
    _reset_chat_quota(user_row)  # refresh today's count if the day rolled over
    return success_response(
        ChatConfigOut(
            session_id=session.id,
            chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
            chat_used_today=user_row.chat_used_today or 0,
            inner_voice_enabled=payload.enabled,
            is_vip=_is_vip(user_row),
        )
    )
@router.get("/config", response_model=ApiResponse[ChatConfigOut])
def get_chat_config(
    session_id: Optional[int] = Query(default=None),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Report the quota, the effective inner-voice flag and VIP state, optionally for one session.

    The flag is resolved session -> user default -> global default. If it
    resolves to on for a non-VIP user, it is reported as off, and a loaded
    session's stored flag is switched off as well.

    Raises:
        HTTPException(404): the given session is not the user's, or the user row is missing.
    """
    # Returns None when no session_id was given.
    session = _fetch_session(db, user.id, session_id)
    user_row = db.query(User).filter(User.id == user.id).first()
    if user_row is None:
        raise HTTPException(status_code=404, detail="用户不存在")
    _reset_chat_quota(user_row)

    if session and session.inner_voice_enabled is not None:
        inner_voice = session.inner_voice_enabled
    elif user_row.inner_voice_enabled is not None:
        inner_voice = user_row.inner_voice_enabled
    else:
        inner_voice = settings.INNER_VOICE_DEFAULT

    if inner_voice and not _is_vip(user_row):
        inner_voice = False
        if session:
            session.inner_voice_enabled = False
            db.add(session)
            db.flush()

    config = ChatConfigOut(
        session_id=session.id if session else None,
        chat_limit_daily=user_row.chat_limit_daily or settings.CHAT_LIMIT_DAILY,
        chat_used_today=user_row.chat_used_today or 0,
        inner_voice_enabled=bool(inner_voice),
        is_vip=_is_vip(user_row),
    )
    return success_response(config)
# ==================== 消息编辑功能 ====================
class EditMessageIn(BaseModel):
    """Request body for ``PUT /chat/messages/{message_id}/edit``."""

    # Replacement text, 1-4000 characters (description: "new message content").
    new_content: str = Field(..., min_length=1, max_length=4000, description="新的消息内容")
    # Whether to regenerate the following AI reply (description: "regenerate AI reply?").
    regenerate_reply: bool = Field(default=True, description="是否重新生成AI回复")
class EditMessageOut(BaseModel):
    """Result of a message edit."""

    message_id: int
    new_content: str
    # True when the following lover reply was regenerated.
    regenerated: bool
    # The regenerated reply text, when one was produced.
    new_reply: Optional[str] = None
def _reset_message_tts(message: ChatMessage) -> None:
    """Forget cached TTS data for ``message`` so speech is re-synthesized for its new text."""
    message.tts_url = None
    message.tts_status = None
    message.tts_error = None
    message.tts_voice_id = None
    message.tts_model_id = None


@router.put("/messages/{message_id}/edit", response_model=ApiResponse[EditMessageOut])
def edit_message(
    message_id: int,
    payload: EditMessageIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Edit a past text message and optionally regenerate the AI reply that follows it.

    Both user and lover text messages may be edited. The first edit keeps the
    pre-edit text in ``original_content``. Any summaries covering the message
    are deleted, and cached TTS audio is dropped because it no longer matches
    the text.

    When a user message is edited with ``regenerate_reply``, the lover message
    at ``seq + 1`` is regenerated. Inner voice and the length cap are handled
    the same way as in the send endpoints.

    Raises:
        HTTPException(404): message not found for this user.
        HTTPException(400): not a user/lover message, or not a text message.
    """
    # 1. Load the message (row-locked).
    msg = (
        db.query(ChatMessage)
        .filter(ChatMessage.id == message_id, ChatMessage.user_id == user.id)
        .with_for_update()
        .first()
    )
    if not msg:
        raise HTTPException(status_code=404, detail="消息不存在")
    # AI (lover) messages may be edited as well.
    if msg.role not in ["user", "lover"]:
        raise HTTPException(status_code=400, detail="只能编辑用户或AI消息")
    # Image messages keep a URL in ``content``; replacing it with text would break them.
    if msg.content_type not in (None, "text"):
        raise HTTPException(status_code=400, detail="只能编辑文本消息")
    # 2. Keep the original text (first edit only).
    if not msg.is_edited:
        msg.original_content = msg.content
    # 3. Apply the edit; cached audio no longer matches the text.
    msg.content = payload.new_content
    msg.is_edited = True
    msg.edited_at = datetime.utcnow()
    _reset_message_tts(msg)
    db.add(msg)
    db.flush()
    # 4. Delete summaries that cover this message (they are rebuilt later).
    db.query(ChatSummary).filter(
        ChatSummary.session_id == msg.session_id,
        ChatSummary.upto_seq >= msg.seq
    ).delete()
    db.flush()
    new_reply = None
    regenerated = False
    # 5. Regenerate the AI reply (only when a user message was edited).
    if payload.regenerate_reply and msg.role == "user":
        # The AI reply immediately following the edited message.
        next_lover_msg = (
            db.query(ChatMessage)
            .filter(
                ChatMessage.session_id == msg.session_id,
                ChatMessage.seq == msg.seq + 1,
                ChatMessage.role == "lover"
            )
            .with_for_update()
            .first()
        )
        if next_lover_msg:
            session = db.query(ChatSession).filter(ChatSession.id == msg.session_id).first()
            lover = db.query(Lover).filter(Lover.id == msg.lover_id).first()
            if session and lover:
                # Inner voice as in the send endpoints: session flag, else the global
                # default, and never for non-VIP users.
                inner_voice_enabled = (
                    session.inner_voice_enabled
                    if session.inner_voice_enabled is not None
                    else settings.INNER_VOICE_DEFAULT
                )
                if inner_voice_enabled:
                    user_row = db.query(User).filter(User.id == user.id).first()
                    if not user_row or not _is_vip(user_row):
                        inner_voice_enabled = False
                # NOTE(review): _build_context_messages loads the newest messages of the
                # session. When an older message is edited, the context therefore also
                # holds later turns, including the reply being replaced. Consider
                # bounding the history at msg.seq.
                context_messages = _build_context_messages(db, session, lover, user, inner_voice_enabled)
                llm_result = chat_completion(
                    context_messages,
                    temperature=settings.LLM_TEMPERATURE,
                    max_tokens=settings.LLM_MAX_TOKENS,
                    seed=_pick_llm_seed(),
                )
                new_reply = llm_result.content or ""
                if not inner_voice_enabled:
                    new_reply = _strip_inner_voice_text(new_reply)
                # Same reply length cap as the send endpoints (0 disables it).
                max_chars = settings.CHAT_REPLY_MAX_CHARS or 0
                if max_chars > 0 and len(new_reply) > max_chars:
                    new_reply = new_reply[:max_chars]
                # Replace the AI reply; its cached audio is stale now.
                next_lover_msg.content = new_reply
                next_lover_msg.is_edited = True
                next_lover_msg.edited_at = datetime.utcnow()
                next_lover_msg.token_input = (llm_result.usage or {}).get("input_tokens")
                next_lover_msg.token_output = (llm_result.usage or {}).get("output_tokens")
                _reset_message_tts(next_lover_msg)
                db.add(next_lover_msg)
                db.flush()
                regenerated = True
                # Rebuild the summary if thresholds are reached.
                _maybe_generate_summary(db, session, upto_seq=next_lover_msg.seq)
    return success_response(
        EditMessageOut(
            message_id=msg.id,
            new_content=msg.content,
            regenerated=regenerated,
            new_reply=new_reply,
        )
    )