495 lines
15 KiB
Python
495 lines
15 KiB
Python
from typing import List, Optional
|
||
from decimal import Decimal
|
||
import time
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
from pydantic import BaseModel, ConfigDict
|
||
from sqlalchemy.orm import Session
|
||
|
||
from ..db import get_db
|
||
from ..models import VoiceLibrary, Lover, User, UserMoneyLog
|
||
from ..response import ApiResponse, success_response
|
||
from ..deps import get_current_user, AuthedUser
|
||
from sqlalchemy.exc import IntegrityError
|
||
|
||
# All configuration endpoints below are mounted under /config.
router = APIRouter(tags=["config"], prefix="/config")
|
||
|
||
|
||
def _parse_owned_voices(raw: Optional[str]) -> set[int]:
|
||
owned: set[int] = set()
|
||
if not raw:
|
||
return owned
|
||
if isinstance(raw, list):
|
||
for v in raw:
|
||
try:
|
||
owned.add(int(v))
|
||
except Exception:
|
||
continue
|
||
return owned
|
||
if isinstance(raw, str):
|
||
try:
|
||
import json
|
||
|
||
parsed = json.loads(raw)
|
||
if isinstance(parsed, list):
|
||
for v in parsed:
|
||
try:
|
||
owned.add(int(v))
|
||
except Exception:
|
||
continue
|
||
return owned
|
||
except Exception:
|
||
for part in str(raw).split(","):
|
||
part = part.strip()
|
||
if part.isdigit():
|
||
owned.add(int(part))
|
||
return owned
|
||
return owned
|
||
|
||
|
||
def _serialize_owned_voices(ids: set[int]) -> list[int]:
|
||
return sorted(list(ids))
|
||
|
||
|
||
def _ensure_balance(user_row: User) -> Decimal:
|
||
try:
|
||
return Decimal(str(user_row.money or "0"))
|
||
except Exception:
|
||
return Decimal("0")
|
||
|
||
|
||
class VoiceOut(BaseModel):
    """One voice-library entry as returned by the voice list endpoints.

    ``from_attributes=True`` allows building the model from an ORM row's
    attributes as well as from keyword arguments.
    """

    id: int
    name: str
    gender: str  # "male" / "female" -- the values the endpoints filter on
    style_tag: Optional[str] = None
    avatar_url: Optional[str] = None
    sample_audio_url: Optional[str] = None
    tts_model_id: Optional[str] = None  # cloned voices use "cosyvoice-v2"
    is_default: bool = False
    voice_code: str  # provider voice identifier; cloned voices store the CosyVoice voice id here
    is_owned: bool  # whether the voice is usable without buying; /voices/available sets it to owned-or-free
    price_gold: int  # price in gold coins; <= 0 is treated as free by the purchase flow

    model_config = ConfigDict(from_attributes=True)
|
||
|
||
|
||
class VoiceListResponse(BaseModel):
    """Payload of ``GET /config/voices``."""

    voices: List[VoiceOut]
    # Default voice for the requested gender; None when no gender filter was given.
    default_voice_id: Optional[int] = None
    # Voice currently selected by the caller's lover, if any.
    selected_voice_id: Optional[int] = None
|
||
|
||
|
||
class VoiceMallItem(BaseModel):
    """A paid voice listed in the gold-coin mall (``GET /config/voices/mall``)."""

    id: int
    name: str
    gender: str
    style_tag: Optional[str] = None
    avatar_url: Optional[str] = None
    sample_audio_url: Optional[str] = None
    price_gold: int  # price in gold coins

    model_config = ConfigDict(from_attributes=True)
|
||
|
||
|
||
class VoiceMallResponse(BaseModel):
    """Payload of ``GET /config/voices/mall``."""

    voices: List[VoiceMallItem]  # paid voices matching the lover's gender
    owned_voice_ids: List[int]  # ids the user already owns, ascending
    balance: float  # user's gold balance, converted to float for the response
|
||
|
||
|
||
class VoicePurchaseIn(BaseModel):
    """Request body of ``POST /config/voices/purchase``."""

    voice_id: int  # id of the VoiceLibrary row to buy
|
||
|
||
|
||
class VoicePurchaseOut(BaseModel):
    """Result of a successful voice purchase."""

    voice_id: int  # the voice that was bought
    balance: float  # gold balance after the purchase
    owned_voice_ids: List[int]  # all owned voice ids after the purchase, ascending
|
||
|
||
|
||
class VoiceAvailableResponse(BaseModel):
    """Payload of ``GET /config/voices/available``."""

    gender: str  # the lover's gender the list was filtered by
    voices: List[VoiceOut]  # voices usable right now (owned or free)
    selected_voice_id: Optional[int] = None  # lover's current voice, None if unset
|
||
|
||
|
||
@router.get("/voices", response_model=ApiResponse[VoiceListResponse])
def list_voices(
    gender: Optional[str] = Query(default=None, pattern="^(male|female)$"),
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """List library voices, optionally filtered by gender.

    Args:
        gender: optional ``male``/``female`` filter (enforced by the query pattern).
        db: database session.
        user: the authenticated caller.

    Returns:
        ``success_response`` wrapping a ``VoiceListResponse`` with:
          * all matching voices, ordered by id, with per-user ``is_owned``;
          * the default voice id for ``gender`` (only when ``gender`` is given);
          * the lover's selected voice id, when there is no gender filter or
            the lover's gender matches it.

    Raises:
        HTTPException(404): no voice matches.
    """
    query = db.query(VoiceLibrary)
    if gender:
        query = query.filter(VoiceLibrary.gender == gender)
    voices = query.order_by(VoiceLibrary.id.asc()).all()
    if not voices:
        raise HTTPException(status_code=404, detail="未配置音色")

    # A default only makes sense within one gender. `voices` is already filtered
    # by that gender, so take the first flagged row (lowest id) instead of re-querying.
    default_voice = next((v for v in voices if v.is_default), None) if gender else None

    # Ownership is per user, so read it from the user row rather than from the
    # library-wide VoiceLibrary.is_owned column.
    user_row = db.query(User).filter(User.id == user.id).first()
    owned_ids = _parse_owned_voices(user_row.owned_voice_ids) if user_row else set()

    selected_voice_id = None
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    # With a gender filter, only report the selection if the lover's gender matches it.
    if lover and ((not gender) or lover.gender == gender):
        selected_voice_id = lover.voice_id

    # Build VoiceOut explicitly so NULL price_gold / is_default columns don't
    # fail validation (same normalisation as /voices/available).
    voices_out = [
        VoiceOut(
            id=v.id,
            name=v.name,
            gender=v.gender,
            style_tag=v.style_tag,
            avatar_url=v.avatar_url,
            sample_audio_url=v.sample_audio_url,
            tts_model_id=v.tts_model_id,
            is_default=bool(v.is_default),
            voice_code=v.voice_code,
            is_owned=int(v.id) in owned_ids or (v.price_gold or 0) <= 0,
            price_gold=v.price_gold or 0,
        )
        for v in voices
    ]

    return success_response(
        VoiceListResponse(
            voices=voices_out,
            default_voice_id=default_voice.id if default_voice else None,
            selected_voice_id=selected_voice_id,
        )
    )
|
||
|
||
|
||
@router.get("/voices/mall", response_model=ApiResponse[VoiceMallResponse])
def list_paid_voices(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Gold-coin mall listing.

    Returns every paid voice (``price_gold > 0``) matching the current lover's
    gender, the ids the user already owns, and the user's gold balance.

    Raises:
        HTTPException(404): the caller has no lover, or the user row is missing.
    """
    current_lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not current_lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    owner = db.query(User).filter(User.id == user.id).first()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    owned = _parse_owned_voices(owner.owned_voice_ids)
    gold_balance = float(_ensure_balance(owner))

    paid_rows = (
        db.query(VoiceLibrary)
        .filter(VoiceLibrary.gender == current_lover.gender, VoiceLibrary.price_gold > 0)
        .order_by(VoiceLibrary.id.asc())
        .all()
    )

    mall_items = [
        VoiceMallItem(
            id=row.id,
            name=row.name,
            gender=row.gender,
            style_tag=row.style_tag,
            avatar_url=row.avatar_url,
            sample_audio_url=row.sample_audio_url,
            price_gold=row.price_gold or 0,
        )
        for row in paid_rows
    ]

    response_body = VoiceMallResponse(
        voices=mall_items,
        owned_voice_ids=_serialize_owned_voices(owned),
        balance=gold_balance,
    )
    return success_response(response_body)
|
||
|
||
|
||
@router.post("/voices/purchase", response_model=ApiResponse[VoicePurchaseOut])
def purchase_voice(
    payload: VoicePurchaseIn,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Buy a paid voice for the caller with gold coins.

    The voice must exist, match the lover's gender, and have a positive price.
    The user row is read with ``SELECT ... FOR UPDATE``, so concurrent purchases
    serialise on the balance and ownership checks. On success the price is
    deducted, the voice id is added to ``owned_voice_ids``, and a
    ``UserMoneyLog`` entry is written. The session is only flushed here.

    Raises:
        HTTPException(404): no lover, unknown/mismatched voice, or missing user row.
        HTTPException(400): free voice, already owned, or insufficient balance.
        HTTPException(409): the flush hit an integrity conflict.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    voice = (
        db.query(VoiceLibrary)
        .filter(VoiceLibrary.id == payload.voice_id, VoiceLibrary.gender == lover.gender)
        .first()
    )
    if not voice:
        raise HTTPException(status_code=404, detail="音色不存在或与恋人性别不匹配")
    price = Decimal(voice.price_gold or 0)
    if price <= 0:
        raise HTTPException(status_code=400, detail="该音色不需要购买")

    # Row lock for the read-check-write below. DB errors (e.g. lock wait timeout)
    # now propagate instead of being misreported as "user not found".
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    owned_ids = _parse_owned_voices(user_row.owned_voice_ids)
    if int(voice.id) in owned_ids:
        raise HTTPException(status_code=400, detail="已拥有该音色,无需重复购买")

    before_balance = _ensure_balance(user_row)
    if before_balance < price:
        raise HTTPException(status_code=400, detail="余额不足")

    # Deduct and record ownership while holding the row lock.
    after_balance = before_balance - price
    user_row.money = float(after_balance)
    owned_ids.add(int(voice.id))
    user_row.owned_voice_ids = _serialize_owned_voices(owned_ids)
    db.add(user_row)
    db.add(
        UserMoneyLog(
            user_id=user.id,
            money=-price,
            before=before_balance,
            # Log the exact Decimal; Decimal(float) would record the binary
            # float expansion (e.g. 89.900000000000005684...).
            after=after_balance,
            memo=f"购买音色:{voice.name}",
            createtime=int(time.time()),
        )
    )
    try:
        db.flush()
    except IntegrityError:
        db.rollback()
        raise HTTPException(status_code=409, detail="购买请求冲突,请重试")

    return success_response(
        VoicePurchaseOut(
            voice_id=voice.id,
            balance=float(after_balance),
            owned_voice_ids=_serialize_owned_voices(owned_ids),
        )
    )
|
||
|
||
|
||
@router.get("/voices/available", response_model=ApiResponse[VoiceAvailableResponse])
def list_available_voices_for_lover(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the voices the current lover can switch to right now.

    This is every voice of the lover's gender that is either free
    (``price_gold`` NULL or <= 0) or already owned by the user. Unowned paid
    voices are left out so the "change voice" page can offer the list directly.

    Raises:
        HTTPException(404): no lover, no user row, or no usable voice.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")
    gender = lover.gender

    user_row = db.query(User).filter(User.id == user.id).first()
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")

    owned_ids = _parse_owned_voices(user_row.owned_voice_ids)

    # In SQL, `NULL <= 0` is NULL (not true), so a bare `price_gold <= 0` would drop
    # NULL-priced voices even though `(price_gold or 0) <= 0` below treats them as free.
    is_free = VoiceLibrary.price_gold.is_(None) | (VoiceLibrary.price_gold <= 0)
    query = db.query(VoiceLibrary).filter(VoiceLibrary.gender == gender)
    if owned_ids:
        # sorted() hands in_() a plain list and keeps the generated SQL stable.
        query = query.filter(is_free | VoiceLibrary.id.in_(sorted(owned_ids)))
    else:
        query = query.filter(is_free)
    voices = query.order_by(VoiceLibrary.id.asc()).all()
    if not voices:
        raise HTTPException(status_code=404, detail="未配置音色")

    voices_out: List[VoiceOut] = [
        VoiceOut(
            id=v.id,
            name=v.name,
            gender=v.gender,
            style_tag=v.style_tag,
            avatar_url=v.avatar_url,
            sample_audio_url=v.sample_audio_url,
            tts_model_id=v.tts_model_id,
            is_default=bool(v.is_default),
            voice_code=v.voice_code,
            is_owned=int(v.id) in owned_ids or (v.price_gold or 0) <= 0,
            price_gold=v.price_gold or 0,
        )
        for v in voices
    ]

    return success_response(
        VoiceAvailableResponse(
            gender=gender,
            voices=voices_out,
            selected_voice_id=lover.voice_id if lover.voice_id else None,
        ),
        msg="获取可用音色成功",
    )
|
||
|
||
|
||
# ===== 音色克隆相关 =====
|
||
|
||
class VoiceCloneRequest(BaseModel):
    """Request body of ``POST /config/voices/clone``."""

    audio_url: str  # URL of the sample audio handed to the CosyVoice clone service
    voice_name: str  # used as the CosyVoice prefix; the endpoint rejects names longer than 10 chars
    gender: str  # male/female -- validated by the endpoint, not by this model

    model_config = ConfigDict(from_attributes=True)
|
||
|
||
|
||
class VoiceCloneResponse(BaseModel):
    """Result of starting a clone job."""

    voice_id: str  # id returned by the clone service; use it to poll status / save
    status: str  # the clone endpoint always returns "PENDING"
    message: str  # human-readable status message

    model_config = ConfigDict(from_attributes=True)
|
||
|
||
|
||
class VoiceCloneStatusResponse(BaseModel):
    """Status of a clone job as reported by the clone service."""

    voice_id: str
    status: str  # PENDING, OK, UNDEPLOYED, FAILED ("UNKNOWN" if the service omits it)
    # VoiceLibrary row id when the voice is OK and already saved; otherwise None.
    voice_library_id: Optional[int] = None

    model_config = ConfigDict(from_attributes=True)
|
||
|
||
|
||
@router.post("/voices/clone", response_model=ApiResponse[VoiceCloneResponse])
def clone_voice(
    payload: VoiceCloneRequest,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Start a CosyVoice clone job from an uploaded audio file's URL.

    Returns the service's voice id with status ``PENDING``; poll
    ``/voices/clone/{voice_id}/status`` for progress.

    Raises:
        HTTPException(400): name longer than 10 characters or invalid gender.
        HTTPException(500): the clone service call failed.
    """
    from ..cosyvoice_clone import create_voice_from_url

    # CosyVoice limits the voice prefix to at most 10 characters.
    if len(payload.voice_name) > 10:
        raise HTTPException(status_code=400, detail="音色名称不能超过10个字符")

    # NOTE(review): gender is validated but not persisted; save_cloned_voice
    # takes the gender from the user's lover instead -- confirm intended.
    if payload.gender not in ("male", "female"):
        raise HTTPException(status_code=400, detail="性别必须是 male 或 female")

    # Only the external service call is wrapped; the cause is chained so the
    # original traceback survives in logs.
    try:
        voice_id = create_voice_from_url(
            audio_url=payload.audio_url,
            prefix=payload.voice_name,
            target_model="cosyvoice-v2",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"克隆失败: {str(e)}") from e

    return success_response(
        VoiceCloneResponse(
            voice_id=voice_id,
            status="PENDING",
            message="音色克隆任务已创建,请稍后查询状态",
        )
    )
|
||
|
||
|
||
@router.get("/voices/clone/{voice_id}/status", response_model=ApiResponse[VoiceCloneStatusResponse])
def get_clone_status(
    voice_id: str,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Query the clone service for a cloned voice's status.

    When the status is ``OK`` and a VoiceLibrary row with ``voice_code == voice_id``
    exists, its id is returned as ``voice_library_id``.

    Raises:
        HTTPException(500): the status query failed.
    """
    from ..cosyvoice_clone import query_voice

    # NOTE(review): `user` is not used, so any authenticated user can query any
    # voice_id -- confirm there is no ownership requirement.
    try:
        info = query_voice(voice_id)
        status = info.get("status", "UNKNOWN")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e

    voice_library_id = None
    if status == "OK":
        existing = (
            db.query(VoiceLibrary)
            .filter(VoiceLibrary.voice_code == voice_id)
            .first()
        )
        if existing:
            voice_library_id = existing.id

    return success_response(
        VoiceCloneStatusResponse(
            voice_id=voice_id,
            status=status,
            voice_library_id=voice_library_id,
        )
    )
|
||
|
||
|
||
@router.post("/voices/clone/{voice_id}/save", response_model=ApiResponse[dict])
def save_cloned_voice(
    voice_id: str,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Save a successfully cloned voice into the voice library.

    The clone must report status ``OK``. If a row with ``voice_code == voice_id``
    already exists, including one inserted by a concurrent request, its id is
    returned with "音色已存在" instead of inserting a duplicate.

    Returns:
        success_response wrapping ``{"voice_library_id": int, "message": str}``.

    Raises:
        HTTPException(400): the clone is not in status ``OK``.
        HTTPException(500): any other failure (the session is rolled back).
    """
    from ..cosyvoice_clone import query_voice

    def _find_existing():
        # Library row already holding this provider voice id, if any.
        return db.query(VoiceLibrary).filter(VoiceLibrary.voice_code == voice_id).first()

    try:
        info = query_voice(voice_id)
        status = info.get("status")
        if status != "OK":
            raise HTTPException(status_code=400, detail=f"音色状态为 {status},无法保存")

        existing = _find_existing()
        if existing:
            return success_response({"voice_library_id": existing.id, "message": "音色已存在"})

        # `or` also covers an explicit null/empty name from the provider.
        voice_name = info.get("name") or "克隆音色"

        # The lover's gender decides which gender-filtered lists show this voice.
        lover = db.query(Lover).filter(Lover.user_id == user.id).first()
        gender = lover.gender if lover else "female"

        # NOTE(review): the row has price_gold=0 and no owner column, so
        # /voices/available treats it as free for every user whose lover has
        # this gender -- confirm that user-cloned voices should be shared.
        new_voice = VoiceLibrary(
            name=voice_name,
            gender=gender,
            style_tag="克隆音色",
            avatar_url=None,
            sample_audio_url=None,
            tts_model_id="cosyvoice-v2",
            is_default=False,
            voice_code=voice_id,
            is_owned=True,
            price_gold=0,
        )
        db.add(new_voice)
        try:
            db.flush()
        except IntegrityError:
            # Presumably a unique constraint on voice_code hit by a concurrent
            # save (TODO confirm the constraint); return the winning row if present.
            db.rollback()
            existing = _find_existing()
            if existing:
                return success_response({"voice_library_id": existing.id, "message": "音色已存在"})
            raise

        return success_response({
            "voice_library_id": new_voice.id,
            "message": "音色保存成功",
        })
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=f"保存失败: {str(e)}") from e
|