Ai_GirlFriend/lover/routers/config.py

675 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from decimal import Decimal, InvalidOperation
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, ConfigDict
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.orm import Session

from ..db import get_db
from ..deps import AuthedUser, get_current_user
from ..models import Lover, User, UserMoneyLog, VoiceLibrary
from ..response import ApiResponse, success_response
# Shared router for this module; every route declared below is mounted under /config.
router = APIRouter(tags=["config"], prefix="/config")
def _parse_owned_voices(raw: "str | list | None") -> set[int]:
    """
    Parse a stored ``owned_voice_ids`` value into a set of voice ids.

    Accepted shapes:

    * ``None`` / empty value -> empty set
    * a list/tuple/set of ids (e.g. what a JSON column returns)
    * a bare integer id
    * a JSON array string such as ``"[1, 2]"``
    * a single id string such as ``"5"`` (JSON-decodes to an int)
    * a comma-separated string such as ``"1,2,3"`` (optionally JSON-quoted)

    Entries that cannot be converted to ``int`` are skipped.
    """
    import json

    owned: set[int] = set()
    if not raw:
        return owned

    def _add_all(values) -> None:
        # Best effort: skip entries that are not integer-like.
        for value in values:
            try:
                owned.add(int(value))
            except (TypeError, ValueError, OverflowError):
                continue

    if isinstance(raw, (list, tuple, set, frozenset)):
        _add_all(raw)
        return owned
    if isinstance(raw, int) and not isinstance(raw, bool):
        owned.add(raw)
        return owned
    if not isinstance(raw, str):
        return owned

    try:
        parsed = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None

    if isinstance(parsed, list):
        _add_all(parsed)
        return owned
    # A single id such as "5" decodes to an int and must not be dropped.
    if isinstance(parsed, int) and not isinstance(parsed, bool):
        owned.add(parsed)
        return owned

    # Fall back to comma-separated ids; a JSON-quoted string ('"1,2"') is unwrapped first.
    text = parsed if isinstance(parsed, str) else raw
    for part in text.split(","):
        part = part.strip()
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse, e.g. rejects "²".
        if part.isdecimal():
            owned.add(int(part))
    return owned
def _serialize_owned_voices(ids: set[int]) -> list[int]:
    """Return owned voice ids as an ascending list (the format stored on the user and sent to clients)."""
    return sorted(ids)
def _ensure_balance(user_row: "User") -> Decimal:
    """
    Return ``user_row.money`` as a ``Decimal``, defaulting to ``Decimal("0")``.

    Empty values (``None``, ``0``, ``""``), values that are not valid decimals,
    and non-finite values (NaN/Infinity) all map to zero. Non-finite values
    are rejected because ordering comparisons on a NaN ``Decimal`` (e.g.
    ``balance < price``) raise ``InvalidOperation``.
    """
    try:
        balance = Decimal(str(user_row.money or "0"))
    except (InvalidOperation, ValueError):
        return Decimal("0")
    return balance if balance.is_finite() else Decimal("0")
class VoiceOut(BaseModel):
    # One voice entry returned by the voice listing endpoints.
    id: int
    name: str
    gender: str  # the endpoints filter on "male" / "female"
    style_tag: Optional[str] = None
    avatar_url: Optional[str] = None
    sample_audio_url: Optional[str] = None
    tts_model_id: Optional[str] = None
    is_default: bool = False
    voice_code: str  # upstream TTS voice identifier (cloned voices store the CosyVoice voice_id here)
    is_owned: bool  # ownership flag; how it is computed depends on the endpoint that builds it
    price_gold: int  # price in gold coins; <= 0 is treated as free by the purchase endpoint
    model_config = ConfigDict(from_attributes=True)  # allow validation from ORM objects
class VoiceListResponse(BaseModel):
    # Payload of the GET /voices endpoint.
    voices: List[VoiceOut]
    default_voice_id: Optional[int] = None  # only resolved when the request filters by gender
    selected_voice_id: Optional[int] = None  # the lover's current voice, if it matches the filter
class VoiceMallItem(BaseModel):
    # One paid voice listed in the gold-coin mall (GET /voices/mall).
    id: int
    name: str
    gender: str
    style_tag: Optional[str] = None
    avatar_url: Optional[str] = None
    sample_audio_url: Optional[str] = None
    price_gold: int  # price in gold coins
    model_config = ConfigDict(from_attributes=True)  # allow validation from ORM objects
class VoiceMallResponse(BaseModel):
    # Payload of GET /voices/mall.
    voices: List[VoiceMallItem]
    owned_voice_ids: List[int]  # ids the user already owns, ascending
    balance: float  # the user's gold balance
class VoicePurchaseIn(BaseModel):
    # Request body of POST /voices/purchase.
    voice_id: int  # VoiceLibrary id to buy
class VoicePurchaseOut(BaseModel):
    # Result of a successful POST /voices/purchase.
    voice_id: int
    balance: float  # gold balance after the debit
    owned_voice_ids: List[int]  # updated owned ids, ascending
class VoiceAvailableResponse(BaseModel):
    # Payload of GET /voices/available (voices the lover can switch to).
    gender: str  # the lover's gender used for filtering
    voices: List[VoiceOut]
    selected_voice_id: Optional[int] = None  # the lover's current voice, if any
@router.get("/voices", response_model=ApiResponse[VoiceListResponse])
def list_voices(
    gender: Optional[str] = Query(default=None, pattern="^(male|female)$"),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    List library voices, optionally filtered by ``gender`` ("male"/"female").

    ``is_owned`` is computed for the current user: true when the voice id is
    in the user's ``owned_voice_ids`` or the voice is free (``price_gold``
    empty or <= 0), the same rule ``/voices/available`` applies.

    ``default_voice_id`` is only resolved when ``gender`` is given;
    ``selected_voice_id`` is the lover's current voice, reported only when it
    matches the requested gender (or no gender was requested).

    Raises:
        HTTPException(404): no voices match the filter.
    """
    query = db.query(VoiceLibrary)
    if gender:
        query = query.filter(VoiceLibrary.gender == gender)
    voices = query.order_by(VoiceLibrary.id.asc()).all()
    if not voices:
        raise HTTPException(status_code=404, detail="未配置音色")

    # A default voice is only defined per gender.
    default_voice = None
    if gender:
        default_voice = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.gender == gender, VoiceLibrary.is_default.is_(True))
            .first()
        )

    user_row = db.query(User).filter(User.id == user.id).first()
    owned_ids = _parse_owned_voices(user_row.owned_voice_ids) if user_row else set()

    selected_voice_id = None
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    # When the client filters by gender, only report the selection if it matches.
    if lover and (not gender or lover.gender == gender):
        selected_voice_id = lover.voice_id

    voices_out: List[VoiceOut] = []
    for v in voices:
        # NULL price is treated as free; also avoids None failing int validation.
        price = v.price_gold or 0
        voices_out.append(
            VoiceOut(
                id=v.id,
                name=v.name,
                gender=v.gender,
                style_tag=v.style_tag,
                avatar_url=v.avatar_url,
                sample_audio_url=v.sample_audio_url,
                tts_model_id=v.tts_model_id,
                is_default=bool(v.is_default),
                voice_code=v.voice_code,
                # Per-user ownership instead of the library-wide VoiceLibrary.is_owned column.
                is_owned=int(v.id) in owned_ids or price <= 0,
                price_gold=price,
            )
        )

    return success_response(
        VoiceListResponse(
            voices=voices_out,
            default_voice_id=default_voice.id if default_voice else None,
            selected_voice_id=selected_voice_id,
        )
    )
@router.get("/voices/mall", response_model=ApiResponse[VoiceMallResponse])
def list_paid_voices(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Gold-coin voice mall.

    Lists every paid voice (``price_gold > 0``) whose gender matches the
    current user's lover, ordered by id, together with the ids the user
    already owns and the user's gold balance.

    Raises:
        HTTPException(404): the user has no lover, or the user row is missing.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    account = db.query(User).filter(User.id == user.id).first()
    if not account:
        raise HTTPException(status_code=404, detail="用户不存在")

    owned_sorted = _serialize_owned_voices(_parse_owned_voices(account.owned_voice_ids))
    gold = float(_ensure_balance(account))

    paid_criteria = (VoiceLibrary.gender == lover.gender, VoiceLibrary.price_gold > 0)
    catalogue = (
        db.query(VoiceLibrary)
        .filter(*paid_criteria)
        .order_by(VoiceLibrary.id.asc())
        .all()
    )

    items = [
        VoiceMallItem(
            id=item.id,
            name=item.name,
            gender=item.gender,
            style_tag=item.style_tag,
            avatar_url=item.avatar_url,
            sample_audio_url=item.sample_audio_url,
            price_gold=item.price_gold or 0,
        )
        for item in catalogue
    ]
    return success_response(
        VoiceMallResponse(voices=items, owned_voice_ids=owned_sorted, balance=gold)
    )
@router.post("/voices/purchase", response_model=ApiResponse[VoicePurchaseOut])
def purchase_voice(
    payload: VoicePurchaseIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Buy a paid voice with gold coins.

    The voice must match the current lover's gender and have a positive
    price. The user row is locked (``with_for_update``) while the balance is
    checked and debited, the voice id is added to ``owned_voice_ids``, and a
    ``UserMoneyLog`` entry is written.

    Raises:
        HTTPException(404): no lover, voice missing / wrong gender, or user missing.
        HTTPException(400): voice is free, already owned, or balance too low.
        HTTPException(409): row lock failed (``OperationalError``) or flush hit
            an ``IntegrityError``.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    voice = (
        db.query(VoiceLibrary)
        .filter(VoiceLibrary.id == payload.voice_id, VoiceLibrary.gender == lover.gender)
        .first()
    )
    if not voice:
        raise HTTPException(status_code=404, detail="音色不存在或与恋人性别不匹配")

    price = Decimal(voice.price_gold or 0)
    if price <= 0:
        raise HTTPException(status_code=400, detail="该音色不需要购买")

    # Lock the user row so check-debit-grant is atomic w.r.t. concurrent purchases.
    # Lock failures (timeouts/deadlocks) must not be reported as "user not found".
    try:
        user_row = (
            db.query(User)
            .filter(User.id == user.id)
            .with_for_update()
            .first()
        )
    except OperationalError as exc:
        db.rollback()
        raise HTTPException(status_code=409, detail="购买请求冲突,请重试") from exc
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    owned_ids = _parse_owned_voices(user_row.owned_voice_ids)
    if int(voice.id) in owned_ids:
        raise HTTPException(status_code=400, detail="已拥有该音色,无需重复购买")

    before_balance = _ensure_balance(user_row)
    if before_balance < price:
        raise HTTPException(status_code=400, detail="余额不足")

    after_balance = before_balance - price
    # The column still receives a float as before; the money log records the
    # exact Decimal values (Decimal(float) would carry binary rounding noise).
    user_row.money = float(after_balance)
    owned_ids.add(int(voice.id))
    user_row.owned_voice_ids = _serialize_owned_voices(owned_ids)
    db.add(user_row)
    db.add(
        UserMoneyLog(
            user_id=user.id,
            money=-price,
            before=before_balance,
            after=after_balance,
            memo=f"购买音色:{voice.name}",
            createtime=int(time.time()),
        )
    )
    try:
        db.flush()
    except IntegrityError as exc:
        db.rollback()
        raise HTTPException(status_code=409, detail="购买请求冲突,请重试") from exc

    return success_response(
        VoicePurchaseOut(
            voice_id=voice.id,
            balance=float(after_balance),
            owned_voice_ids=_serialize_owned_voices(owned_ids),
        )
    )
@router.get("/voices/available", response_model=ApiResponse[VoiceAvailableResponse])
def list_available_voices_for_lover(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Return the voices the current lover can switch to: owned voices plus free ones.

    Paid voices the user does not own are excluded; the result feeds the
    "change voice" page. A voice counts as free when ``price_gold`` is NULL
    or <= 0, matching the ``price_gold or 0`` rule used for ``is_owned``.

    Raises:
        HTTPException(404): no lover, no user row, or no usable voices.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")
    gender = lover.gender

    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    owned_ids = _parse_owned_voices(user_row.owned_voice_ids)

    # In SQL, `price_gold <= 0` is NULL (not true) for NULL prices, so test NULL explicitly.
    is_free = (VoiceLibrary.price_gold <= 0) | VoiceLibrary.price_gold.is_(None)
    if owned_ids:
        visible = is_free | VoiceLibrary.id.in_(sorted(owned_ids))
    else:
        visible = is_free

    voices = (
        db.query(VoiceLibrary)
        .filter(VoiceLibrary.gender == gender, visible)
        .order_by(VoiceLibrary.id.asc())
        .all()
    )
    if not voices:
        raise HTTPException(status_code=404, detail="未配置音色")

    voices_out: List[VoiceOut] = []
    for v in voices:
        price = v.price_gold or 0
        voices_out.append(
            VoiceOut(
                id=v.id,
                name=v.name,
                gender=v.gender,
                style_tag=v.style_tag,
                avatar_url=v.avatar_url,
                sample_audio_url=v.sample_audio_url,
                tts_model_id=v.tts_model_id,
                is_default=bool(v.is_default),
                voice_code=v.voice_code,
                is_owned=int(v.id) in owned_ids or price <= 0,
                price_gold=price,
            )
        )

    return success_response(
        VoiceAvailableResponse(
            gender=gender,
            voices=voices_out,
            selected_voice_id=lover.voice_id if lover.voice_id else None,
        ),
        msg="获取可用音色成功",
    )
# ===== 音色克隆相关 =====
class VoiceCloneRequest(BaseModel):
    # Request body of POST /voices/clone.
    audio_url: str  # URL of the sample audio passed to the clone service
    voice_name: str  # display name typed by the user (may be Chinese); the endpoint limits it to 20 chars
    gender: str  # male/female (validated by the endpoint, not by this model)
    model_config = ConfigDict(from_attributes=True)
class VoiceCloneResponse(BaseModel):
    # Result of POST /voices/clone.
    voice_id: str  # voice id returned by the clone service
    status: str  # the create endpoint always reports "PENDING"
    message: str
    model_config = ConfigDict(from_attributes=True)
class VoiceCloneStatusResponse(BaseModel):
    # Result of GET /voices/clone/{voice_id}/status.
    voice_id: str
    status: str  # PENDING, OK, UNDEPLOYED, FAILED (the endpoint uses "UNKNOWN" if upstream omits it)
    voice_library_id: Optional[int] = None  # set once the voice has been saved to VoiceLibrary
    model_config = ConfigDict(from_attributes=True)
@router.post("/voices/clone", response_model=ApiResponse[VoiceCloneResponse])
def clone_voice(
    payload: VoiceCloneRequest,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Start a CosyVoice clone job from an audio URL.

    The user-facing ``voice_name`` may be Chinese, but the CosyVoice API takes
    a short ASCII prefix (at most 10 characters), so a 7-character prefix
    ``"v" + 6 hex chars`` is derived from the user id and the current second.

    Returns the upstream ``voice_id`` with status ``PENDING``; the client then
    polls ``/voices/clone/{voice_id}/status`` and calls ``.../save``.
    ``voice_name`` is not persisted here; the save request carries its own
    ``display_name``.

    Raises:
        HTTPException(400): name longer than 20 characters, or invalid gender.
        HTTPException(500): the clone service call failed.
    """
    from ..cosyvoice_clone import create_voice_from_url
    import hashlib

    if len(payload.voice_name) > 20:
        raise HTTPException(status_code=400, detail="音色名称不能超过20个字符")
    if payload.gender not in ("male", "female"):
        raise HTTPException(status_code=400, detail="性别必须是 male 或 female")

    # Uses the module-level `time` import (the old function-local import shadowed it).
    hash_input = f"{user.id}_{int(time.time())}"
    # md5 is only used to derive a short identifier, not for security.
    hash_code = hashlib.md5(hash_input.encode(), usedforsecurity=False).hexdigest()[:6]
    api_prefix = f"v{hash_code}"  # "v" + 6 hex chars = 7 chars, within the 10-char limit

    try:
        voice_id = create_voice_from_url(
            audio_url=payload.audio_url,
            prefix=api_prefix,
            target_model="cosyvoice-v2",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"克隆失败: {str(e)}") from e

    return success_response(
        VoiceCloneResponse(
            voice_id=voice_id,
            status="PENDING",
            message="音色克隆任务已创建,请稍后查询状态",
        )
    )
@router.get("/voices/clone/{voice_id}/status", response_model=ApiResponse[VoiceCloneStatusResponse])
def get_clone_status(
    voice_id: str,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Query the clone status of a CosyVoice voice.

    When the upstream status is ``OK``, also report the ``VoiceLibrary`` id
    if a row with ``voice_code == voice_id`` has already been saved.

    Raises:
        HTTPException(500): any failure while querying upstream or the
            database; the error text is included in ``detail``.
    """
    from ..cosyvoice_clone import query_voice

    try:
        upstream = query_voice(voice_id)
        current_status = upstream.get("status", "UNKNOWN")

        saved_id = None
        if current_status == "OK":
            saved_row = (
                db.query(VoiceLibrary)
                .filter(VoiceLibrary.voice_code == voice_id)
                .first()
            )
            saved_id = saved_row.id if saved_row else None

        result = VoiceCloneStatusResponse(
            voice_id=voice_id,
            status=current_status,
            voice_library_id=saved_id,
        )
        return success_response(result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
class VoiceCloneSaveRequest(BaseModel):
    # Request body of POST /voices/clone/{voice_id}/save.
    display_name: Optional[str] = None  # user-entered display name (may be Chinese); falls back to the upstream name
    model_config = ConfigDict(from_attributes=True)
@router.post("/voices/clone/{voice_id}/save", response_model=ApiResponse[dict])
def save_cloned_voice(
    voice_id: str,
    payload: VoiceCloneSaveRequest,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Save a successfully cloned voice into the voice library.

    If a ``VoiceLibrary`` row with ``voice_code == voice_id`` already exists,
    its id is returned unchanged. Otherwise a new row is created (gender taken
    from the user's lover, ``"female"`` if none), added to the user's owned
    voices, and set as the lover's current voice.

    Raises:
        HTTPException(400): the upstream status is not ``OK``.
        HTTPException(500): any other failure (the session is rolled back).
    """
    from ..cosyvoice_clone import query_voice

    try:
        info = query_voice(voice_id)
        status = info.get("status")
        if status != "OK":
            raise HTTPException(status_code=400, detail=f"音色状态为 {status},无法保存")

        existing = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.voice_code == voice_id)
            .first()
        )
        if existing:
            return success_response({"voice_library_id": existing.id, "message": "音色已存在"})

        # Prefer the user's display name; fall back to the upstream (ASCII hash) name.
        api_voice_name = info.get("name", "克隆音色")
        display_name = payload.display_name or api_voice_name

        lover = db.query(Lover).filter(Lover.user_id == user.id).first()
        gender = lover.gender if lover else "female"

        # NOTE(review): price_gold=0 makes this voice "free", so /voices/available
        # lists it for every user whose lover has this gender — confirm whether
        # cloned voices should be private to their creator.
        new_voice = VoiceLibrary(
            name=display_name,
            gender=gender,
            style_tag="克隆音色",
            avatar_url=None,
            sample_audio_url=None,
            tts_model_id="cosyvoice-v2",
            is_default=False,
            voice_code=voice_id,  # the upstream voice id
            is_owned=True,
            price_gold=0,
        )
        db.add(new_voice)
        db.flush()  # assigns new_voice.id

        # Lock the user row (as purchase_voice does) so a concurrent purchase
        # cannot overwrite owned_voice_ids, and store it in the same format
        # purchase_voice uses (a sorted list of ints).
        user_row = (
            db.query(User)
            .filter(User.id == user.id)
            .with_for_update()
            .first()
        )
        if user_row:
            owned_ids = _parse_owned_voices(user_row.owned_voice_ids)
            owned_ids.add(new_voice.id)
            user_row.owned_voice_ids = _serialize_owned_voices(owned_ids)
            db.add(user_row)

        # Make the new voice the lover's current voice.
        if lover:
            lover.voice_id = new_voice.id
            db.add(lover)
        db.flush()

        return success_response({
            "voice_library_id": new_voice.id,
            "message": "音色保存成功并已设置为当前音色",
            "display_name": display_name,
        })
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
# ===== 邀请码相关 =====
class InviteInfoResponse(BaseModel):
    # Payload of GET /invite/info.
    invite_code: str  # the user's own invite code
    invite_count: int  # number of users who redeemed this code
    invite_reward_total: float  # total gold earned from invites
    invite_url: str  # shareable registration link embedding the code
    model_config = ConfigDict(from_attributes=True)
class InviteRewardRequest(BaseModel):
    # Request body of POST /invite/apply.
    invite_code: str  # the inviter's code
    model_config = ConfigDict(from_attributes=True)
def _generate_invite_code(length: int = 6) -> str:
    """
    Generate a random invite code of ``length`` characters (default 6).

    The alphabet omits the easily confused characters 0/O and 1/I. Uses
    :mod:`secrets` rather than :mod:`random`, whose output is predictable
    and therefore unsuitable for tokens handed to users.
    """
    import secrets

    chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"  # no 0, O, 1, I
    return "".join(secrets.choice(chars) for _ in range(length))
@router.get("/invite/info", response_model=ApiResponse[InviteInfoResponse])
def get_invite_info(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Return the current user's invite code, invite statistics and share link.

    A missing invite code is generated lazily, re-rolling until no other user
    holds the candidate, and flushed to the session.

    Raises:
        HTTPException(404): the user row does not exist.
    """
    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    if not user_row.invite_code:
        candidate = _generate_invite_code()
        # Re-roll while another user already holds the candidate.
        while db.query(User).filter(User.invite_code == candidate).first():
            candidate = _generate_invite_code()
        user_row.invite_code = candidate
        db.add(user_row)
        db.flush()

    code = user_row.invite_code
    # NOTE(review): placeholder domain; presumably this should come from configuration.
    invite_url = f"https://your-domain.com/register?invite={code}"
    stats = InviteInfoResponse(
        invite_code=code,
        invite_count=user_row.invite_count or 0,
        invite_reward_total=float(user_row.invite_reward_total or 0),
        invite_url=invite_url,
    )
    return success_response(stats)
@router.post("/invite/apply", response_model=ApiResponse[dict])
def apply_invite_code(
    payload: InviteRewardRequest,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Redeem an invite code for the current user (called after registration).

    Locks the invitee row, then the inviter row. Credits the inviter
    ``INVITER_REWARD`` gold and updates ``invite_count`` and
    ``invite_reward_total``. Credits the invitee ``INVITEE_REWARD`` gold and
    records ``invited_by``. Writes one ``UserMoneyLog`` row per credit.

    Raises:
        HTTPException(404): user row missing, or invite code not found.
        HTTPException(400): the user already redeemed a code, or used their own.
        HTTPException(409): flushing hit an ``IntegrityError``.
    """
    # Reward amounts in gold coins. Decimal keeps balances and log values exact;
    # the old float path re-derived "before" by subtraction (10.1 - 10.0 != 0.1).
    INVITER_REWARD = Decimal("10.00")  # credited to the inviter
    INVITEE_REWARD = Decimal("5.00")  # credited to the new user

    # Tolerate stray whitespace from copy/paste; generated codes contain none.
    invite_code = payload.invite_code.strip()

    user_row = db.query(User).filter(User.id == user.id).with_for_update().first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    if user_row.invited_by:
        raise HTTPException(status_code=400, detail="您已经使用过邀请码")
    if user_row.invite_code == invite_code:
        raise HTTPException(status_code=400, detail="不能使用自己的邀请码")

    # NOTE(review): rows are locked invitee-first, so two users redeeming each
    # other's codes at the same moment may deadlock (the DB would likely abort one).
    inviter = (
        db.query(User)
        .filter(User.invite_code == invite_code)
        .with_for_update()
        .first()
    )
    if not inviter:
        raise HTTPException(status_code=404, detail="邀请码不存在")

    now = int(time.time())

    # Credit the inviter.
    inviter_before = _ensure_balance(inviter)
    inviter_after = inviter_before + INVITER_REWARD
    inviter.money = float(inviter_after)
    inviter.invite_count = (inviter.invite_count or 0) + 1
    inviter.invite_reward_total = float(
        Decimal(str(inviter.invite_reward_total or 0)) + INVITER_REWARD
    )
    db.add(inviter)
    db.add(
        UserMoneyLog(
            user_id=inviter.id,
            money=INVITER_REWARD,
            before=inviter_before,
            after=inviter_after,
            memo="邀请新用户奖励",
            createtime=now,
        )
    )

    # Credit the invitee and remember which code they used.
    invitee_before = _ensure_balance(user_row)
    invitee_after = invitee_before + INVITEE_REWARD
    user_row.money = float(invitee_after)
    user_row.invited_by = invite_code
    db.add(user_row)
    db.add(
        UserMoneyLog(
            user_id=user.id,
            money=INVITEE_REWARD,
            before=invitee_before,
            after=invitee_after,
            memo="使用邀请码奖励",
            createtime=now,
        )
    )

    try:
        db.flush()
    except IntegrityError as exc:
        db.rollback()
        raise HTTPException(status_code=409, detail="操作失败,请重试") from exc

    reward = float(INVITEE_REWARD)  # 5.0, keeps the response text/number unchanged
    return success_response({
        "message": f"邀请码使用成功!您获得了{reward}金币",
        "reward": reward,
        "balance": float(user_row.money),
    })