Ai_GirlFriend/lover/routers/dance.py

1489 lines
53 KiB
Python
Raw Normal View History

2026-01-31 19:15:41 +08:00
import hashlib
import os
import random
import shutil
2026-01-31 19:15:41 +08:00
import subprocess
import tempfile
import time
from datetime import datetime
from typing import Optional
import oss2
import requests
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from ..config import settings
from ..db import SessionLocal, get_db
from ..deps import AuthedUser, get_current_user
from ..models import (
ChatMessage,
ChatSession,
GenerationTask,
Lover,
SongLibrary,
User,
)
from ..response import ApiResponse, success_response
try:
import imageio_ffmpeg # type: ignore
except Exception: # pragma: no cover
imageio_ffmpeg = None
2026-01-31 19:15:41 +08:00
# Router for all dance-video endpoints; every route below is mounted under the "/dance" prefix.
router = APIRouter(prefix="/dance", tags=["dance"])
def _ffmpeg_bin() -> str:
    """Return the ffmpeg executable to invoke.

    Lookup order: ``ffmpeg`` on PATH, then the binary bundled with the optional
    ``imageio_ffmpeg`` package, and finally the bare name ``"ffmpeg"``. If that
    name does not exist, ``subprocess`` raises FileNotFoundError.
    """
    on_path = shutil.which("ffmpeg")
    if on_path:
        return on_path
    if imageio_ffmpeg is None:
        return "ffmpeg"
    try:
        return imageio_ffmpeg.get_ffmpeg_exe()
    except Exception:
        # The bundled binary is best-effort; fall back to the plain command name.
        return "ffmpeg"
def _ffprobe_bin() -> str:
    """Return the ffprobe executable to invoke.

    Lookup order: ``ffprobe`` on PATH, then an ``ffprobe.exe`` / ``ffprobe``
    file sitting next to the ffmpeg binary bundled with ``imageio_ffmpeg``, and
    finally the bare name ``"ffprobe"``.
    """
    on_path = shutil.which("ffprobe")
    if on_path:
        return on_path
    if imageio_ffmpeg is not None:
        try:
            ffmpeg_dir = os.path.dirname(imageio_ffmpeg.get_ffmpeg_exe())
            # Windows name first, then the POSIX one.
            for probe_name in ("ffprobe.exe", "ffprobe"):
                probe_path = os.path.join(ffmpeg_dir, probe_name)
                if os.path.exists(probe_path):
                    return probe_path
        except Exception:
            pass
    return "ffprobe"
2026-01-31 19:15:41 +08:00
# Length in seconds of the final dance video's BGM track. It is passed as the
# target duration for BGM clipping/padding and recorded in task payloads as
# "output_duration" / "bgm_duration".
DANCE_TARGET_DURATION_SEC = 10
2026-02-02 20:08:28 +08:00
@router.get("/history", response_model=ApiResponse[list[dict]])
def get_dance_history(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    """Return the user's successfully generated dance videos, newest first.

    Pagination is 1-based. Out-of-range values are clamped (``page >= 1`` and
    ``1 <= size <= 100``) so a bad query string can never produce a negative
    OFFSET/LIMIT (a database error) or an unbounded page.

    Raises:
        HTTPException: 404 when the user has no lover.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")
    page = max(page, 1)
    size = min(max(size, 1), 100)
    offset = (page - 1) * size
    tasks = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            # Dance tasks store their prompt in the payload; presumably other
            # video task types do not -- confirm.
            GenerationTask.payload["prompt"].as_string().isnot(None),
            GenerationTask.status == "succeeded",
            GenerationTask.result_url.isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .offset(offset)
        .limit(size)
        .all()
    )
    result: list[dict] = [
        {
            "id": task.id,
            "prompt": (task.payload or {}).get("prompt") or "",
            "video_url": _cdnize(task.result_url) or "",
            "created_at": task.created_at.isoformat() if task.created_at else None,
        }
        for task in tasks
    ]
    return success_response(result, msg="获取成功")
def _download_binary(url: str) -> bytes:
    """Fetch ``url`` into memory and return the response body.

    Raises:
        HTTPException: 502 on any request error or a non-200 status.
    """
    try:
        response = requests.get(url, timeout=30)
    except Exception as exc:
        raise HTTPException(status_code=502, detail="文件下载失败") from exc
    if response.status_code == 200:
        return response.content
    raise HTTPException(status_code=502, detail="文件下载失败")
def _retry_finalize_dance_task(task_id: int) -> None:
    """Re-finalize a task whose DashScope job succeeded but local post-processing failed.

    This is shared by self-healing and the manual retry endpoint. It only acts
    when the task has a ``dashscope_task_id`` and DashScope reports
    ``SUCCEEDED`` with a video URL. In that case it:

    1. downloads that video,
    2. merges a freshly picked random BGM,
    3. uploads the result to OSS,
    4. marks the task ``succeeded``.

    Otherwise it does nothing.

    This runs as a background task, so errors are never raised; they are
    logged instead of being silently discarded. The task row is locked only for
    the final write, not across the download/ffmpeg/upload work. Other
    transactions that lock the same row (status polling, failure marking)
    therefore do not block for the duration of the merge.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        # Read-only snapshot: no row lock is held while talking to remote services.
        with SessionLocal() as db:
            task = db.query(GenerationTask).filter(GenerationTask.id == task_id).first()
            if not task:
                return
            dash_id = (task.payload or {}).get("dashscope_task_id")
            if not dash_id:
                return
            lover_id = task.lover_id
            status, dash_video_url = _fetch_dashscope_status(str(dash_id))
            if status != "SUCCEEDED" or not dash_video_url:
                return
            # Regenerate: DashScope video -> random BGM -> merge -> upload.
            bgm_song = _pick_random_bgm(db)
            bgm_song_id = bgm_song.id
            bgm_audio_url_raw = bgm_song.audio_url

        bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw
        merged_bytes, bgm_meta = _merge_dance_video_with_bgm(
            dash_video_url,
            bgm_audio_url,
            DANCE_TARGET_DURATION_SEC,
        )
        object_name = f"lover/{lover_id}/dance/{int(time.time())}_retry.mp4"
        oss_url = _upload_to_oss(merged_bytes, object_name)

        # Short locked transaction for the write-back. Merge into the latest
        # payload so concurrent payload updates are not lost.
        with SessionLocal() as db:
            task = (
                db.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if not task:
                return
            task.status = "succeeded"
            task.result_url = oss_url
            task.error_msg = None
            task.payload = {
                **(task.payload or {}),
                "dashscope_video_url": dash_video_url,
                "bgm_song_id": bgm_song_id,
                "bgm_audio_url": bgm_audio_url,
                "bgm_audio_url_raw": bgm_audio_url_raw,
                "bgm_start_sec": bgm_meta.get("bgm_start_sec"),
                "bgm_duration": DANCE_TARGET_DURATION_SEC,
                "retry_finalized": True,
            }
            task.updated_at = datetime.utcnow()
            db.add(task)
            db.commit()
    except Exception:
        # Best-effort background job: never propagate, but leave a trace.
        logger.exception("dance retry finalize failed: task_id=%s", task_id)
@router.post("/retry/{task_id}", response_model=ApiResponse[dict])
def retry_dance_task(
    task_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Manually retry a failed dance task whose DashScope job may already have succeeded.

    This is intended for tasks that failed locally (download/merge/upload)
    after DashScope had already finished. Only ``failed`` tasks are accepted:

    - Re-finalizing a pending/running task would race the in-flight worker.
    - Re-finalizing a succeeded task would replace its result with a newly
      merged video that has a different random BGM.

    Raises:
        HTTPException: 404 when the lover or task is not found; 400 when the
            task is not in ``failed`` state or has no ``dashscope_task_id``.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")
    task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.id == task_id,
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    if task.status != "failed":
        raise HTTPException(status_code=400, detail="仅失败的任务可以重试")
    payload = task.payload or {}
    dash_id = payload.get("dashscope_task_id")
    if not dash_id:
        raise HTTPException(status_code=400, detail="任务缺少 dashscope_task_id无法重试")
    task.payload = {**payload, "manual_retry": True}
    task.updated_at = datetime.utcnow()
    db.add(task)
    db.commit()
    background_tasks.add_task(_retry_finalize_dance_task, int(task.id))
    return success_response({"task_id": int(task.id)}, msg="已触发重试")
@router.get("/history/all", response_model=ApiResponse[list[dict]])
def get_dance_history_all(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    """Return all of the user's dance tasks (succeeded, failed and in progress), newest first.

    ``video_url`` is the uploaded result when available, otherwise any URL
    recorded in the payload. Pagination is clamped exactly like ``/history``
    (``page >= 1``, ``1 <= size <= 100``) to avoid negative OFFSET/LIMIT.

    Raises:
        HTTPException: 404 when the user has no lover.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")
    page = max(page, 1)
    size = min(max(size, 1), 100)
    offset = (page - 1) * size
    tasks = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .offset(offset)
        .limit(size)
        .all()
    )
    result: list[dict] = []
    for task in tasks:
        payload = task.payload or {}
        url = task.result_url or payload.get("video_url") or payload.get("dashscope_video_url") or ""
        result.append(
            {
                "id": int(task.id),
                "prompt": payload.get("prompt") or "",
                "status": task.status,
                "video_url": _cdnize(url) or "",
                "error_msg": task.error_msg,
                "created_at": task.created_at.isoformat() if task.created_at else None,
            }
        )
    return success_response(result, msg="获取成功")
2026-01-31 19:15:41 +08:00
# Request body for POST /dance/generate.
class DanceGenerateIn(BaseModel):
    # Free-text description of the dance/motion (2-400 chars); used as the video-generation prompt.
    prompt: str = Field(..., min_length=2, max_length=400, description="用户希望跳的舞/动作描述")
# Status snapshot of one dance generation task, returned by the generate / current / task-status endpoints.
class DanceTaskStatusOut(BaseModel):
    # GenerationTask primary key.
    generation_task_id: int
    # Task lifecycle state.
    status: str = Field(..., description="pending|running|succeeded|failed")
    # DashScope async task id; empty before the job has been submitted.
    dashscope_task_id: str = ""
    # Result video URL; empty while generation is still in progress.
    video_url: str = ""
    # Chat session and placeholder message ids the task reports into; 0 when unknown.
    session_id: int = 0
    user_message_id: int = 0
    lover_message_id: int = 0
    # Failure reason recorded on the task, if any.
    error_msg: Optional[str] = None
@router.get("/current", response_model=ApiResponse[Optional[DanceTaskStatusOut]])
def get_current_dance_task(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the user's most recent pending/running dance task, or ``None``.

    ``video_url`` is passed through ``_cdnize``, as the history endpoints do,
    so a relative OSS object key is returned as an absolute URL.
    """
    task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
            GenerationTask.status.in_(["pending", "running"]),
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .first()
    )
    if not task:
        return success_response(None, msg="暂无进行中的任务")
    payload = task.payload or {}
    raw_url = task.result_url or payload.get("video_url") or payload.get("dashscope_video_url") or ""
    return success_response(
        DanceTaskStatusOut(
            generation_task_id=task.id,
            status=task.status,
            dashscope_task_id=str(payload.get("dashscope_task_id") or ""),
            video_url=_cdnize(raw_url) or "",
            session_id=int(payload.get("session_id") or 0),
            user_message_id=int(payload.get("user_message_id") or 0),
            lover_message_id=int(payload.get("lover_message_id") or 0),
            error_msg=task.error_msg,
        ),
        msg="获取成功",
    )
2026-01-31 19:15:41 +08:00
def _upload_to_oss(file_bytes: bytes, object_name: str) -> str:
    """Upload ``file_bytes`` to Aliyun OSS under ``object_name`` and return its public URL.

    The URL uses the CDN domain when one is configured. Otherwise it is
    ``https://<bucket>.<endpoint host>/<object_name>``.

    Raises:
        HTTPException: 500 when OSS credentials, bucket or endpoint are not configured.
    """
    key_id = settings.ALIYUN_OSS_ACCESS_KEY_ID
    key_secret = settings.ALIYUN_OSS_ACCESS_KEY_SECRET
    if not (key_id and key_secret):
        raise HTTPException(status_code=500, detail="未配置 OSS Key")
    bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
    if not (bucket_name and settings.ALIYUN_OSS_ENDPOINT):
        raise HTTPException(status_code=500, detail="未配置 OSS Bucket/Endpoint")
    endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/")
    oss_bucket = oss2.Bucket(oss2.Auth(key_id, key_secret), endpoint, bucket_name)
    oss_bucket.put_object(object_name, file_bytes)
    cdn_domain = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn_domain:
        return f"{cdn_domain.rstrip('/')}/{object_name}"
    return f"https://{bucket_name}.{endpoint.replace('https://', '').replace('http://', '')}/{object_name}"
def _is_own_oss_url(url: str) -> bool:
    """Return True when ``url`` points at an object in this project's OSS bucket.

    A URL matches when it starts with the configured CDN domain followed by
    "/", or with ``https://<bucket>.<endpoint host>/``. These are the two URL
    shapes ``_upload_to_oss`` produces. The separator is required so that a
    host which merely starts with the CDN domain (e.g.
    ``https://cdn.example.com.evil.net/...``) is not mistaken for ours.
    """
    if not url:
        return False
    cdn = (settings.ALIYUN_OSS_CDN_DOMAIN or "").rstrip("/")
    if cdn and url.startswith(f"{cdn}/"):
        return True
    bucket = settings.ALIYUN_OSS_BUCKET_NAME
    endpoint = (settings.ALIYUN_OSS_ENDPOINT or "").rstrip("/")
    if bucket and endpoint:
        domain = endpoint.replace("https://", "").replace("http://", "")
        return url.startswith(f"https://{bucket}.{domain}/")
    return False
def _cdnize(url: Optional[str]) -> Optional[str]:
    """Turn a stored OSS object key into an absolute URL.

    Values that are already absolute (http/https) are only stripped of
    surrounding whitespace. Falsy values are returned unchanged. Relative keys
    are joined to the first available base, in this order:

    1. the CDN domain,
    2. ``https://<bucket>.<endpoint host>``,
    3. a hard-coded fallback bucket domain.
    """
    if not url:
        return url
    candidate = url.strip()
    if candidate.startswith(("http://", "https://")):
        return candidate
    object_key = candidate.lstrip("/")
    cdn_domain = settings.ALIYUN_OSS_CDN_DOMAIN
    bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
    endpoint = settings.ALIYUN_OSS_ENDPOINT
    if cdn_domain:
        return f"{cdn_domain.rstrip('/')}/{object_key}"
    if bucket_name and endpoint:
        host = endpoint.rstrip("/").replace("https://", "").replace("http://", "")
        return f"https://{bucket_name}.{host}/{object_key}"
    return f"https://nvlovers.oss-cn-qingdao.aliyuncs.com/{object_key}"
def _shorten_url(url: str, max_len: int = 160) -> str:
    """Truncate ``url`` for use in error messages, ending with "..." when cut.

    The result is never longer than ``max_len`` characters. When ``max_len``
    is below 3 there is no room for the ellipsis, so a plain prefix is
    returned. Previously the negative slice bound produced a string longer
    than ``max_len`` in that case.
    """
    if len(url) <= max_len:
        return url
    if max_len < 3:
        return url[: max(max_len, 0)]
    return f"{url[: max_len - 3]}..."
def _submit_video_task(prompt: str, image_url: str) -> str:
    """Submit an async image-to-video job to DashScope and return its task id.

    Model, resolution and duration come from settings, with defaults
    ``wan2.2-i2v-flash``, ``480P`` and 5 respectively. Watermarking is off and
    prompt extension is on.

    Raises:
        HTTPException: 500 when ``DASHSCOPE_API_KEY`` is missing; 502 when the
            request fails, returns a non-200 status, returns an unparsable or
            non-object body, or carries no task id.
    """
    if not settings.DASHSCOPE_API_KEY:
        raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
    resolution = settings.VIDEO_GEN_RESOLUTION or "480P"
    duration = settings.VIDEO_GEN_DURATION or 5
    payload = {
        "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
        "input": {
            "prompt": prompt,
            "img_url": image_url,
        },
        "parameters": {
            "resolution": resolution,
            "duration": duration,
            "watermark": False,
            "prompt_extend": True,
        },
    }
    headers = {
        "X-DashScope-Async": "enable",
        "Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        resp = requests.post(
            "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis",
            headers=headers,
            json=payload,
            timeout=10,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail="调用视频生成接口失败") from exc
    if resp.status_code != 200:
        msg = resp.text
        try:
            msg = resp.json().get("message") or msg
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"视频任务提交失败: {msg}")
    try:
        data = resp.json()
    except Exception as exc:
        raise HTTPException(status_code=502, detail="视频任务返回解析失败") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail="视频任务返回解析失败")
    # `or {}` (not `.get("output", {})`) also covers an explicit `"output": null`,
    # matching how the polling helpers read the response.
    output = data.get("output") or {}
    task_id = output.get("task_id") or data.get("task_id") or output.get("id")
    if not task_id:
        raise HTTPException(status_code=502, detail="视频任务未返回 task_id")
    return str(task_id)
def _poll_video_url(task_id: str, timeout_sec: float = 120, interval_sec: float = 3) -> str:
    """Poll a DashScope async task until it finishes and return its video URL.

    Transient problems (network errors, non-200 responses, unparsable JSON)
    are retried until the polling budget is used up.

    Args:
        task_id: DashScope task id, as returned by ``_submit_video_task``.
        timeout_sec: Total polling budget in seconds (default 120, the former
            hard-coded value).
        interval_sec: Sleep before each poll in seconds (default 3).

    Returns:
        The ``video_url`` of the SUCCEEDED task.

    Raises:
        HTTPException: 502 when the task ends FAILED or CANCELED, or succeeds
            without a URL; 504 when the budget runs out.
    """
    headers = {"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"}
    query_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
    deadline = time.time() + timeout_sec
    while time.time() < deadline:
        time.sleep(interval_sec)
        try:
            resp = requests.get(query_url, headers=headers, timeout=8)
        except Exception:
            continue
        if resp.status_code != 200:
            continue
        try:
            data = resp.json()
        except Exception:
            continue
        output = data.get("output") or {}
        status_str = str(
            output.get("task_status")
            or data.get("task_status")
            or data.get("status")
            or ""
        ).upper()
        if status_str == "SUCCEEDED":
            url = output.get("video_url") or data.get("video_url")
            if not url:
                raise HTTPException(status_code=502, detail="视频生成成功但未返回结果 URL")
            return url
        # CANCELED is terminal as well. Without it a cancelled job would be
        # polled until the deadline and misreported as a timeout.
        if status_str in ("FAILED", "CANCELED"):
            msg = output.get("message") or data.get("message") or "生成失败"
            raise HTTPException(status_code=502, detail=f"视频生成失败: {msg}")
    raise HTTPException(status_code=504, detail="视频生成超时,请稍后重试")
def _fetch_dashscope_status(task_id: str) -> tuple[str, Optional[str]]:
    """Query a DashScope task once, without retrying.

    Returns:
        ``(status, video_url)``. ``status`` is the upper-cased task status
        (e.g. ``SUCCEEDED`` / ``FAILED``); an empty string means the response
        carried no status. ``("UNKNOWN", None)`` is returned when the request
        fails, returns a non-200 status, or the body is not JSON.
        ``video_url`` is None when absent.
    """
    unknown: tuple[str, Optional[str]] = ("UNKNOWN", None)
    try:
        response = requests.get(
            f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}",
            headers={"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"},
            timeout=8,
        )
    except Exception:
        return unknown
    if response.status_code != 200:
        return unknown
    try:
        body = response.json()
    except Exception:
        return unknown
    output = body.get("output") or {}
    raw_status = output.get("task_status") or body.get("task_status") or body.get("status") or ""
    return str(raw_status).upper(), output.get("video_url") or body.get("video_url")
def _download_to_path(url: str, target_path: str, label: str):
    """Stream ``url`` into the file ``target_path`` in 1 MiB chunks.

    ``label`` names the resource in error messages. The HTTP response is
    always closed, including on a non-200 status. Previously that path leaked
    the streamed connection.

    Raises:
        HTTPException: 502 for an empty URL, a request error (before or during
            streaming), or a non-200 status.
    """
    if not url:
        raise HTTPException(status_code=502, detail=f"{label}下载失败: 空URL")
    try:
        resp = requests.get(url, stream=True, timeout=30)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"{label}下载失败: {_shorten_url(url)}") from exc
    try:
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"{label}下载失败({resp.status_code}): {_shorten_url(url)}")
        with open(target_path, "wb") as file_handle:
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    file_handle.write(chunk)
    except requests.RequestException as exc:
        # Errors raised mid-stream (connection reset, read timeout) are
        # reported like other download failures.
        raise HTTPException(status_code=502, detail=f"{label}下载失败: {_shorten_url(url)}") from exc
    finally:
        resp.close()
def _probe_media_duration(path: str) -> Optional[float]:
    """Return the container duration of the media file at ``path``, in seconds.

    Runs ffprobe asking only for ``format=duration`` as a bare value. Returns
    None when ffprobe is missing or fails, prints nothing, prints a non-number,
    or reports a duration <= 0.
    """
    probe_cmd = [
        _ffprobe_bin(),
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    try:
        completed = subprocess.run(probe_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    text = completed.stdout.decode("utf-8", errors="ignore").strip()
    if not text:
        return None
    try:
        seconds = float(text)
    except ValueError:
        return None
    return None if seconds <= 0 else seconds
def _run_ffmpeg_merge(video_path: str, audio_path: str, output_path: str):
    """Mux ``audio_path`` onto ``video_path`` and write an H.264/AAC MP4 to ``output_path``.

    The video input is looped indefinitely (``-stream_loop -1``). The output
    length is bounded by ``-t <audio duration>`` (when ffprobe can measure it)
    and by ``-shortest``.

    Raises:
        HTTPException: 500 when ffmpeg is not available; 502 when ffmpeg exits
            non-zero, with the first 200 chars of its stderr.
    """
    audio_seconds = _probe_media_duration(audio_path)
    input_args = ["-stream_loop", "-1", "-i", video_path, "-i", audio_path]
    length_args = ["-t", f"{audio_seconds:.3f}"] if audio_seconds else []
    encode_args = [
        "-map", "0:v:0",
        "-map", "1:a:0",
        "-c:v", "libx264",
        "-preset", "veryfast",
        "-crf", "23",
        "-pix_fmt", "yuv420p",
        "-c:a", "aac",
        "-shortest",
        "-movflags", "+faststart",
    ]
    command = [_ffmpeg_bin(), "-y", "-loglevel", "error", *input_args, *length_args, *encode_args, output_path]
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        stderr = (exc.stderr or b"").decode("utf-8", errors="ignore")
        raise HTTPException(status_code=502, detail=f"ffmpeg 合成失败: {stderr[:200]}") from exc
def _extract_audio_segment(
    input_path: str,
    start_sec: float,
    duration_sec: float,
    output_path: str,
):
    """Cut ``duration_sec`` seconds starting at ``start_sec`` out of ``input_path``.

    The result is re-encoded as 128 kbps, 44.1 kHz stereo MP3 (video streams
    dropped) and written to ``output_path``.

    Raises:
        HTTPException: 500 when ffmpeg is not available; 502 when ffmpeg exits non-zero.
    """
    mp3_args = ["-acodec", "libmp3lame", "-b:a", "128k", "-ar", "44100", "-ac", "2"]
    command = [
        _ffmpeg_bin(),
        "-y", "-loglevel", "error",
        "-i", input_path,
        "-ss", f"{start_sec:.3f}",
        "-t", f"{duration_sec:.3f}",
        "-vn",
        *mp3_args,
        output_path,
    ]
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        stderr = (exc.stderr or b"").decode("utf-8", errors="ignore")
        raise HTTPException(status_code=502, detail=f"音频分段失败: {stderr[:200]}") from exc
def _pad_audio_segment(
    input_path: str,
    pad_sec: float,
    target_duration_sec: float,
    output_path: str,
):
    """Append ``pad_sec`` seconds of silence to ``input_path`` and write the result as MP3.

    ``-t`` caps the output at ``target_duration_sec``. When ``pad_sec <= 0``
    nothing runs and ``output_path`` is not created, so callers must only use
    ``output_path`` when padding was actually needed.

    Raises:
        HTTPException: 500 when ffmpeg is not available; 502 when ffmpeg exits non-zero.
    """
    if pad_sec > 0:
        mp3_args = ["-acodec", "libmp3lame", "-b:a", "128k", "-ar", "44100", "-ac", "2"]
        command = [
            _ffmpeg_bin(),
            "-y", "-loglevel", "error",
            "-i", input_path,
            "-af", f"apad=pad_dur={pad_sec:.3f}",
            "-t", f"{target_duration_sec:.3f}",
            *mp3_args,
            output_path,
        ]
        try:
            subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except FileNotFoundError as exc:
            raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
        except subprocess.CalledProcessError as exc:
            stderr = (exc.stderr or b"").decode("utf-8", errors="ignore")
            raise HTTPException(status_code=502, detail=f"音频补齐失败: {stderr[:200]}") from exc
def _pick_random_bgm(db: Session) -> SongLibrary:
    """Return a uniformly random enabled, non-deleted song that has an audio URL.

    The selection counts the candidates and then fetches the row at a random
    OFFSET.

    Raises:
        HTTPException: 404 when no usable song exists, or when the picked row
            vanished or lost its URL between the count and the fetch.
    """
    candidates = db.query(SongLibrary).filter(
        SongLibrary.status.is_(True),
        SongLibrary.deletetime.is_(None),
        SongLibrary.audio_url.isnot(None),
        SongLibrary.audio_url != "",
    )
    available = candidates.count()
    picked = None
    if available > 0:
        picked = candidates.offset(random.randrange(available)).limit(1).first()
    if not picked or not picked.audio_url:
        raise HTTPException(status_code=404, detail="音频库暂无可用音频")
    return picked
def _merge_dance_video_with_bgm(
    base_video_url: str,
    audio_url: str,
    target_duration_sec: int,
) -> tuple[bytes, dict]:
    """Download a dance video and a BGM track, and return the merged MP4 plus BGM metadata.

    The BGM clip is always ``target_duration_sec`` long:

    - A longer track is cut at a random start offset.
    - A shorter track is used from 0 and padded with silence.

    The clip is then muxed onto the (looped) video. All work happens in a
    temporary directory that is removed afterwards.

    Returns:
        ``(mp4_bytes, {"bgm_start_sec": <offset rounded to ms>, "bgm_duration": target_duration_sec})``.

    Raises:
        HTTPException: 502 when the track's duration cannot be determined, plus
            anything raised by the download/ffmpeg helpers.
    """
    with tempfile.TemporaryDirectory() as workdir:
        video_file = os.path.join(workdir, "base.mp4")
        source_audio = os.path.join(workdir, "audio_source.mp3")
        _download_to_path(base_video_url, video_file, "跳舞视频")
        _download_to_path(_cdnize(audio_url) or audio_url, source_audio, "BGM音频")
        source_seconds = _probe_media_duration(source_audio)
        if not source_seconds:
            raise HTTPException(status_code=502, detail="音频时长解析失败")
        if source_seconds > target_duration_sec:
            # Any window of the required length that fits inside the track.
            offset_sec = random.uniform(0.0, source_seconds - target_duration_sec)
            segment_sec = float(target_duration_sec)
        else:
            offset_sec = 0.0
            segment_sec = min(source_seconds, float(target_duration_sec))
        segment_file = os.path.join(workdir, "audio_clip.mp3")
        _extract_audio_segment(source_audio, offset_sec, segment_sec, segment_file)
        if segment_sec < target_duration_sec:
            padded_file = os.path.join(workdir, "audio_clip_pad.mp3")
            _pad_audio_segment(
                segment_file,
                target_duration_sec - segment_sec,
                target_duration_sec,
                padded_file,
            )
            segment_file = padded_file
        merged_file = os.path.join(workdir, "merged.mp4")
        _run_ffmpeg_merge(video_file, segment_file, merged_file)
        with open(merged_file, "rb") as merged_handle:
            merged_bytes = merged_handle.read()
    return merged_bytes, {
        "bgm_start_sec": round(offset_sec, 3),
        "bgm_duration": target_duration_sec,
    }
def _get_or_create_session(db: Session, user: AuthedUser, lover: Lover, session_id: Optional[int]) -> ChatSession:
    """Resolve the chat session a dance task should report into.

    - With ``session_id``: return that session if it belongs to ``user``;
      otherwise raise 404.
    - Without it: lock the user's active sessions with this lover and keep the
      newest one, marking any others ``archived``.
    - If there is no active session, create and flush a new one.

    Changes are added to ``db`` but not committed; the caller commits.
    """
    if session_id:
        requested = (
            db.query(ChatSession)
            .filter(ChatSession.id == session_id, ChatSession.user_id == user.id)
            .first()
        )
        if not requested:
            raise HTTPException(status_code=404, detail="会话不存在")
        return requested

    active = (
        db.query(ChatSession)
        .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id, ChatSession.status == "active")
        .with_for_update()
        .order_by(ChatSession.created_at.desc())
        .all()
    )
    if active:
        newest, *older = active
        # Only one active session per user/lover pair should survive.
        for stale in older:
            if stale.status == "active":
                stale.status = "archived"
                db.add(stale)
        return newest

    created_at = datetime.utcnow()
    fresh = ChatSession(
        user_id=user.id,
        lover_id=lover.id,
        model=settings.LLM_MODEL or "qwen-flash",
        status="active",
        last_message_at=created_at,
        created_at=created_at,
        updated_at=created_at,
        inner_voice_enabled=False,
    )
    db.add(fresh)
    db.flush()
    return fresh
def _next_seq(db: Session, session_id: int) -> int:
    """Return the sequence number for the next message in ``session_id``.

    The result is the highest existing ``seq`` + 1, or 1 when the session has
    no messages or its latest ``seq`` is empty/0.
    """
    latest = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    current = 0
    if latest and latest.seq:
        current = latest.seq
    return current + 1
def _mark_task_failed(db: Session, task_id: int, msg: str):
    """Record a generation failure on the task and its placeholder chat message, then commit.

    Steps, all under the task row lock:

    - If the payload references a lover message, rewrite it as a failure
      notice.
    - Bump the session's ``last_message_at``.
    - Set the task to ``failed`` with ``msg`` truncated to 255 characters.

    A task that is already ``succeeded`` is left untouched and the transaction
    is rolled back. The worker can be scheduled more than once for the same
    task, and the retry path can finalize a task concurrently, so a late
    failure from one run must not clobber a result another run delivered.
    """
    task = (
        db.query(GenerationTask)
        .filter(GenerationTask.id == task_id)
        .with_for_update()
        .first()
    )
    if not task:
        return
    if task.status == "succeeded":
        db.rollback()
        return
    payload = task.payload or {}
    lover_msg_id = payload.get("lover_message_id")
    session_id = payload.get("session_id")
    if lover_msg_id:
        lover_msg = (
            db.query(ChatMessage)
            .filter(ChatMessage.id == lover_msg_id)
            .with_for_update()
            .first()
        )
        if lover_msg:
            lover_msg.content = f"跳舞视频生成失败:{msg}"
            lover_msg.extra = {
                **(lover_msg.extra or {}),
                "generation_status": "failed",
                "error_msg": msg,
                "generation_task_id": task_id,
            }
            lover_msg.tts_status = lover_msg.tts_status or "pending"
            db.add(lover_msg)
    if session_id:
        session = (
            db.query(ChatSession)
            .filter(ChatSession.id == session_id)
            .with_for_update()
            .first()
        )
        if session:
            session.last_message_at = datetime.utcnow()
            db.add(session)
    task.status = "failed"
    task.error_msg = msg[:255]
    task.updated_at = datetime.utcnow()
    db.add(task)
    db.commit()
def _process_dance_task(task_id: int):
    """Background worker that drives one dance-video task to completion.

    Phase 1 locks the task row, snapshots its payload, validates the lover
    (exists, has an image) and the user's remaining quota, and marks the task
    ``running``. Tasks already ``succeeded``/``failed`` are skipped.

    Phase 2 does the slow work without holding a DB session across it:

    - submits (or reuses) the DashScope job and polls for the video,
    - picks a random BGM and merges it in,
    - uploads the result to OSS.

    Short sessions are opened only to persist progress.

    Phase 3 writes the result back in one short transaction:

    - fills in (or creates) the user/lover chat messages,
    - deducts one unit of video quota unless ``payload["deducted"]`` is
      already set,
    - marks the task ``succeeded``.

    Every failure is recorded via ``_mark_task_failed``; nothing is raised to
    the caller. Each failure path rolls back before marking, so a failed
    statement does not leave the session unusable. Every session opened here
    is closed; the Phase-2 failure session previously leaked.
    """
    # NOTE(review): get_dance_task schedules this worker again on every poll
    # while the task is pending/running, so several copies can run for one task
    # concurrently; the "deducted" flag is what prevents double charging.
    prompt = ""
    image_url = ""
    session_id = None
    user_message_id = None
    lover_message_id = None
    dash_task_id = None
    lover_id = None
    bgm_song_id = None
    bgm_audio_url_raw = None
    bgm_audio_url = None
    bgm_start_sec = None

    # Phase 1: read the task, mark it running, and release the connection quickly.
    db = SessionLocal()
    try:
        task = (
            db.query(GenerationTask)
            .filter(GenerationTask.id == task_id)
            .with_for_update()
            .first()
        )
        if not task or task.status in ("succeeded", "failed"):
            db.rollback()
            return
        user_id = task.user_id
        lover_id = task.lover_id
        payload = task.payload or {}
        prompt = payload.get("prompt") or ""
        image_url = payload.get("image_url")
        session_id = payload.get("session_id")
        user_message_id = payload.get("user_message_id")
        lover_message_id = payload.get("lover_message_id")
        dash_task_id = payload.get("dashscope_task_id")
        # Basic existence checks so a broken task fails fast.
        lover = db.query(Lover).filter(Lover.id == lover_id).first()
        user_row = db.query(User).filter(User.id == user_id).first()
        if not lover:
            raise HTTPException(status_code=404, detail="恋人不存在,请重新创建")
        if not lover.image_url:
            raise HTTPException(status_code=400, detail="请先生成并确认恋人形象")
        if not user_row:
            raise HTTPException(status_code=404, detail="用户不存在")
        if (user_row.video_gen_remaining or 0) <= 0:
            raise HTTPException(status_code=400, detail="视频生成次数不足")
        image_url = image_url or lover.image_url
        task.status = "running"
        task.updated_at = datetime.utcnow()
        db.add(task)
        db.commit()
    except Exception as exc:
        detail = str(exc.detail) if isinstance(exc, HTTPException) else str(exc)[:255]
        try:
            db.rollback()
            _mark_task_failed(db, task_id, detail)
        except Exception:
            pass
        return
    finally:
        db.close()

    # Phase 2: submit/poll DashScope, merge BGM and upload, with no DB session
    # held open across the slow calls.
    try:
        if not dash_task_id:
            dash_task_id = _submit_video_task(prompt, image_url)
            # Persist the DashScope id right away so later runs can reuse it.
            with SessionLocal() as db_tmp:
                task_row = (
                    db_tmp.query(GenerationTask)
                    .filter(GenerationTask.id == task_id)
                    .with_for_update()
                    .first()
                )
                if task_row:
                    task_row.payload = {**(task_row.payload or {}), "dashscope_task_id": dash_task_id}
                    task_row.status = "running"
                    task_row.updated_at = datetime.utcnow()
                    db_tmp.add(task_row)
                    db_tmp.commit()
        dash_video_url = _poll_video_url(dash_task_id)
        with SessionLocal() as db_tmp:
            bgm_song = _pick_random_bgm(db_tmp)
            bgm_song_id = bgm_song.id
            bgm_audio_url_raw = bgm_song.audio_url
            bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw
            task_row = (
                db_tmp.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if task_row:
                task_row.payload = {
                    **(task_row.payload or {}),
                    "bgm_song_id": bgm_song_id,
                    "bgm_audio_url": bgm_audio_url,
                    "bgm_audio_url_raw": bgm_audio_url_raw,
                }
                task_row.updated_at = datetime.utcnow()
                db_tmp.add(task_row)
                db_tmp.commit()
        merged_bytes, bgm_meta = _merge_dance_video_with_bgm(
            dash_video_url,
            bgm_audio_url,
            DANCE_TARGET_DURATION_SEC,
        )
        bgm_start_sec = bgm_meta.get("bgm_start_sec")
        safe_prompt_tag = "prompt"
        object_name = f"lover/{lover_id}/dance/{int(time.time())}_{safe_prompt_tag}.mp4"
        oss_url = _upload_to_oss(merged_bytes, object_name)
    except Exception as exc:
        detail = str(exc.detail) if hasattr(exc, "detail") else str(exc)
        try:
            # `with` guarantees this failure-recording session is closed.
            with SessionLocal() as db_fail:
                _mark_task_failed(db_fail, task_id, detail)
        except Exception:
            pass
        return

    # Phase 3: write back the result and chat messages (short transaction).
    db = SessionLocal()
    try:
        task_row = (
            db.query(GenerationTask)
            .filter(GenerationTask.id == task_id)
            .with_for_update()
            .first()
        )
        if not task_row:
            db.rollback()
            return
        lover = db.query(Lover).filter(Lover.id == task_row.lover_id).first()
        user_row = (
            db.query(User)
            .filter(User.id == task_row.user_id)
            .with_for_update()
            .first()
        )
        if not lover or not user_row:
            raise HTTPException(status_code=404, detail="用户或恋人不存在")

        # Reuse the session reserved at submission time, or get/create one.
        session = None
        if session_id:
            session = (
                db.query(ChatSession)
                .filter(ChatSession.id == session_id, ChatSession.user_id == user_row.id)
                .with_for_update()
                .first()
            )
        if not session:
            session = _get_or_create_session(
                db,
                AuthedUser(
                    id=user_row.id,
                    reg_step=user_row.reg_step or 0,
                    gender=user_row.gender or 0,
                    nickname=user_row.nickname or user_row.username or "",
                    token=user_row.token or "",
                ),
                lover,
                None,
            )

        # User message (normally reserved at submission time).
        user_msg = None
        if user_message_id:
            user_msg = (
                db.query(ChatMessage)
                .filter(ChatMessage.id == user_message_id, ChatMessage.session_id == session.id)
                .with_for_update()
                .first()
            )
        if not user_msg:
            now = datetime.utcnow()
            next_seq = _next_seq(db, session.id)
            user_msg = ChatMessage(
                session_id=session.id,
                user_id=user_row.id,
                lover_id=lover.id,
                role="user",
                content_type="text",
                content=prompt,
                seq=next_seq,
                created_at=now,
                model=settings.LLM_MODEL or "qwen-flash",
            )
            db.add(user_msg)
            db.flush()
        next_seq = user_msg.seq or _next_seq(db, session.id)

        # Lover message: the reserved placeholder, or a new one right after the user message.
        lover_msg = None
        if lover_message_id:
            lover_msg = (
                db.query(ChatMessage)
                .filter(ChatMessage.id == lover_message_id, ChatMessage.session_id == session.id)
                .with_for_update()
                .first()
            )
        if not lover_msg:
            lover_msg = ChatMessage(
                session_id=session.id,
                user_id=user_row.id,
                lover_id=lover.id,
                role="lover",
                content_type="text",
                content="正在为你生成跳舞视频,完成后会自动更新此消息",
                seq=next_seq + 1,
                created_at=datetime.utcnow(),
                model=settings.LLM_MODEL or "qwen-flash",
                extra={
                    "generation_task_id": task_row.id,
                    "generation_status": "pending",
                    "prompt": prompt,
                },
                tts_status="pending",
            )
            db.add(lover_msg)
            db.flush()
        lover_msg.content = f"为你生成了一段跳舞视频,点击查看:{oss_url}"
        lover_msg.extra = {
            **(lover_msg.extra or {}),
            "video_url": oss_url,
            "dashscope_video_url": dash_video_url,
            "dashscope_task_id": dash_task_id,
            "generation_task_id": task_row.id,
            "prompt": prompt,
            "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
            "resolution": settings.VIDEO_GEN_RESOLUTION or "480P",
            "duration": DANCE_TARGET_DURATION_SEC,
            "base_duration": settings.VIDEO_GEN_DURATION or 5,
            "watermark": False,
            "generation_status": "succeeded",
            "bgm_song_id": bgm_song_id,
            "bgm_audio_url": bgm_audio_url,
            "bgm_audio_url_raw": bgm_audio_url_raw,
            "bgm_start_sec": bgm_start_sec,
            "bgm_duration": DANCE_TARGET_DURATION_SEC,
        }
        lover_msg.tts_status = lover_msg.tts_status or "pending"
        db.add(lover_msg)

        # Deduct quota (never below zero, only once per task) and refresh the session time.
        already_deducted = (task_row.payload or {}).get("deducted")
        remaining = user_row.video_gen_remaining or 0
        if remaining > 0 and not already_deducted:
            user_row.video_gen_remaining = remaining - 1
        session.last_message_at = datetime.utcnow()
        db.add(user_row)
        db.add(session)

        task_row.status = "succeeded"
        task_row.result_url = oss_url
        task_row.payload = {
            **(task_row.payload or {}),
            "dashscope_video_url": dash_video_url,
            "dashscope_task_id": dash_task_id,
            "session_id": session.id,
            "user_message_id": user_msg.id,
            "lover_message_id": lover_msg.id,
            "deducted": True,
            "output_duration": DANCE_TARGET_DURATION_SEC,
            "bgm_song_id": bgm_song_id,
            "bgm_audio_url": bgm_audio_url,
            "bgm_audio_url_raw": bgm_audio_url_raw,
            "bgm_start_sec": bgm_start_sec,
            "bgm_duration": DANCE_TARGET_DURATION_SEC,
        }
        task_row.updated_at = datetime.utcnow()
        db.add(task_row)
        db.commit()
    except Exception as exc:
        detail = str(exc.detail) if isinstance(exc, HTTPException) else str(exc)[:255]
        try:
            db.rollback()
            _mark_task_failed(db, task_id, detail)
        except Exception:
            pass
    finally:
        db.close()
@router.post("/generate", response_model=ApiResponse[DanceTaskStatusOut])
def generate_dance_video(
    payload: DanceGenerateIn,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Create a dance-video generation task and schedule ``_process_dance_task`` for it.

    Steps:

    1. Validate the prompt, the user's lover (must exist and have an image)
       and the user's remaining video quota.
    2. Reject with 409 while any video task of the user is pending/running.
    3. Insert a ``GenerationTask`` plus a user/lover chat-message pair; the
       lover message is a placeholder the worker later rewrites.
    4. Commit and schedule the background worker.

    Quota is not deducted here.

    Raises:
        HTTPException: 400 for an empty prompt, a missing lover image or no
            quota left; 404 when the lover or user row is missing; 409 for a
            concurrent or duplicate task.
    """
    prompt = (payload.prompt or "").strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="请输入想要跳的舞/唱的歌/动作描述")
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人不存在,请先完成创建流程")
    if not lover.image_url:
        raise HTTPException(status_code=400, detail="请先生成并确认恋人形象")
    # Row lock on the user serializes concurrent submissions by the same user
    # until this transaction ends.
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    if (user_row.video_gen_remaining or 0) <= 0:
        raise HTTPException(status_code=400, detail="视频生成次数不足")
    # Any pending/running video task blocks a new one. There is no prompt
    # filter here, so this is not limited to dance tasks.
    pending_task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
            GenerationTask.status.in_(["pending", "running"]),
        )
        .first()
    )
    if pending_task:
        raise HTTPException(status_code=409, detail="已有视频生成任务进行中,请稍后再试")
    # Idempotency key derived from user, lover image and prompt.
    idem_key_src = f"video:{user.id}:{lover.image_url}:{prompt}"
    idem_key = hashlib.sha256(idem_key_src.encode("utf-8")).hexdigest()
    task = GenerationTask(
        user_id=user.id,
        lover_id=lover.id,
        task_type="video",
        status="pending",
        idempotency_key=idem_key,
        payload={
            "image_url": lover.image_url,
            "prompt": prompt,
            "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
            "resolution": settings.VIDEO_GEN_RESOLUTION or "480P",
            "duration": settings.VIDEO_GEN_DURATION or 5,
            "watermark": False,
            "output_duration": DANCE_TARGET_DURATION_SEC,
        },
        created_at=datetime.utcnow(),
        updated_at=datetime.utcnow(),
    )
    db.add(task)
    try:
        db.flush()
    except IntegrityError:
        # NOTE(review): this rollback also releases the FOR UPDATE lock on
        # user_row taken above, so the quota/pending checks are no longer
        # protected by it for the rest of this request.
        db.rollback()
        existing = (
            db.query(GenerationTask)
            .filter(GenerationTask.idempotency_key == idem_key)
            .first()
        )
        if existing and existing.status in ("pending", "running"):
            raise HTTPException(status_code=409, detail="已提交相同的视频生成任务,请稍后查看结果")
        # The earlier task with this key has already finished. To avoid the
        # idempotency-key conflict, derive a fresh time-salted key and retry.
        retry_key = hashlib.sha256(f"{idem_key}:{time.time()}".encode()).hexdigest()
        task.idempotency_key = retry_key
        db.add(task)
        db.flush()
    # Reserve the session and message sequence numbers now, so later messages
    # cannot "jump the queue" ahead of this exchange while the video is still
    # being generated.
    session = _get_or_create_session(db, user, lover, None)
    now = datetime.utcnow()
    next_seq = _next_seq(db, session.id)
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="text",
        content=prompt,
        seq=next_seq,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
    )
    db.add(user_msg)
    db.flush()
    # Placeholder reply; _process_dance_task rewrites it with the result or failure.
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content="正在为你生成跳舞视频,完成后会自动更新此消息",
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        extra={
            "generation_task_id": task.id,
            "generation_status": "pending",
            "prompt": prompt,
        },
        tts_status="pending",
    )
    db.add(lover_msg)
    db.flush()
    session.last_message_at = datetime.utcnow()
    db.add(session)
    # Record the reserved ids so the background worker can find them.
    task.payload = {
        **(task.payload or {}),
        "session_id": session.id,
        "user_message_id": user_msg.id,
        "lover_message_id": lover_msg.id,
    }
    db.add(task)
    db.commit()
    background_tasks.add_task(_process_dance_task, task.id)
    return success_response(
        DanceTaskStatusOut(
            generation_task_id=task.id,
            status="pending",
            dashscope_task_id="",
            video_url="",
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            error_msg=None,
        ),
        msg="视频生成任务已提交,正在生成",
    )
@router.get("/generate/{task_id}", response_model=ApiResponse[DanceTaskStatusOut])
def get_dance_task(
    task_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    Return the status of one of the current user's dance-video generation tasks.

    While the task is still ``pending``/``running`` the endpoint doubles as a
    compensating trigger: it re-schedules ``_process_dance_task`` as a
    background task and, when a DashScope task id is already stored in the
    task payload, polls DashScope directly and writes the outcome (success
    with BGM merge / OSS upload / chat message / quota, or failure) back to
    the task row.

    Raises:
        HTTPException: 404 if no video task ``task_id`` belongs to the user.
    """
    in_progress = ("pending", "running")
    task = _dance_status_task_query(db, task_id, user.id).first()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    status_msg_map = {
        "pending": "视频生成中",
        "running": "视频生成中",
        "succeeded": "视频生成成功",
        "failed": "视频生成失败",
    }
    resp_msg = status_msg_map.get(task.status or "", "查询成功")
    if task.status in in_progress:
        # Re-queue the worker without blocking this request (FastAPI runs
        # background tasks after the response has been sent).
        background_tasks.add_task(_process_dance_task, task.id)
        # Re-read through a fresh session so this request's transaction
        # snapshot cannot hide updates already committed by another worker.
        with SessionLocal() as tmp:
            latest = _dance_status_task_query(tmp, task.id, user.id).first()
            if latest:
                task = latest
    dashscope_task_id = (
        (task.payload or {}).get("dashscope_task_id")
        if task.status in in_progress
        else None
    )
    if dashscope_task_id:
        # Still unresolved: ask DashScope directly and back-fill the result or
        # failure, so polling clients do not wait forever.
        status, dash_video = _fetch_dashscope_status(dashscope_task_id)
        with SessionLocal() as tmp:
            current = (
                _dance_status_task_query(tmp, task.id, user.id)
                .with_for_update()
                .first()
            )
            # Re-check under the row lock: a background worker may have
            # finalized the task while DashScope was being polled, in which
            # case it must not be merged/uploaded again or have its final
            # status overwritten. Leaving the `with` block without a commit
            # rolls back and releases the lock.
            if current and current.status in in_progress:
                if status == "SUCCEEDED":
                    _dance_status_apply_success(tmp, current, dash_video)
                elif status == "FAILED":
                    current.status = "failed"
                    current.error_msg = "视频生成失败DashScope 返回失败)"
                    current.updated_at = datetime.utcnow()
                    tmp.add(current)
                    tmp.commit()
    # Read the final state through a fresh session to avoid DetachedInstanceError
    # on objects that belonged to the sessions closed above.
    with SessionLocal() as snap:
        latest = _dance_status_task_query(snap, task_id, user.id).first()
        if not latest:
            raise HTTPException(status_code=404, detail="任务不存在")
        payload = latest.payload or {}
        resp_msg = status_msg_map.get(latest.status or "", resp_msg)
        return success_response(
            DanceTaskStatusOut(
                generation_task_id=latest.id,
                status=latest.status,
                dashscope_task_id=str(payload.get("dashscope_task_id") or ""),
                video_url=latest.result_url or payload.get("dashscope_video_url") or "",
                session_id=int(payload.get("session_id") or 0),
                user_message_id=int(payload.get("user_message_id") or 0),
                lover_message_id=int(payload.get("lover_message_id") or 0),
                error_msg=latest.error_msg,
            ),
            msg=resp_msg,
        )


def _dance_status_task_query(session: Session, task_id: int, user_id):
    """Build a query for video ``GenerationTask`` ``task_id`` owned by ``user_id``."""
    return session.query(GenerationTask).filter(
        GenerationTask.id == task_id,
        GenerationTask.user_id == user_id,
        GenerationTask.task_type == "video",
    )


def _dance_status_apply_success(
    tmp: Session, current: GenerationTask, dash_video: Optional[str]
) -> None:
    """
    Finalize a task that DashScope reports as SUCCEEDED and commit ``tmp``.

    Unless the current result is already one of our own OSS uploads, a random
    BGM track is merged into the DashScope video and the result is uploaded to
    OSS. If that merge/upload raises, the task is marked ``failed`` (the error
    detail is stored) and committed instead. On success the task is marked
    ``succeeded``, and the chat placeholder message plus the user's video
    quota are synced on a best-effort basis.

    Args:
        tmp: Session holding the row lock on ``current``; committed here.
        current: The locked task row, still ``pending``/``running``.
        dash_video: Video URL reported by DashScope. When falsy, the URL is
            taken from the payload or the current result instead.
    """
    import logging

    payload = current.payload or {}
    dash_url = dash_video or payload.get("dashscope_video_url") or current.result_url
    final_url = current.result_url or dash_url
    bgm_song_id = payload.get("bgm_song_id")
    bgm_audio_url = payload.get("bgm_audio_url")
    bgm_start_sec = payload.get("bgm_start_sec")
    # An own-OSS result URL means the merged video was already uploaded.
    if dash_url and not _is_own_oss_url(final_url or ""):
        try:
            bgm_song = _pick_random_bgm(tmp)
            bgm_song_id = bgm_song.id
            bgm_audio_url_raw = bgm_song.audio_url
            bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw
            merged_bytes, bgm_meta = _merge_dance_video_with_bgm(
                dash_url,
                bgm_audio_url,
                DANCE_TARGET_DURATION_SEC,
            )
            bgm_start_sec = bgm_meta.get("bgm_start_sec")
            object_name = f"lover/{current.lover_id}/dance/{int(time.time())}_prompt.mp4"
            final_url = _upload_to_oss(merged_bytes, object_name)
            current.payload = {
                **(current.payload or {}),
                "bgm_song_id": bgm_song_id,
                "bgm_audio_url": bgm_audio_url,
                "bgm_audio_url_raw": bgm_audio_url_raw,
            }
        except Exception as exc:
            # Prefer an HTTPException-style ``detail`` when the error carries one.
            current.status = "failed"
            current.error_msg = str(getattr(exc, "detail", exc))
            current.updated_at = datetime.utcnow()
            tmp.add(current)
            tmp.commit()
            return
    current.status = "succeeded"
    current.result_url = final_url
    current.payload = {
        **(current.payload or {}),
        "dashscope_video_url": dash_url or (current.payload or {}).get("dashscope_video_url"),
        "output_duration": DANCE_TARGET_DURATION_SEC,
        "bgm_song_id": bgm_song_id,
        "bgm_audio_url": bgm_audio_url,
        "bgm_start_sec": bgm_start_sec,
        "bgm_duration": DANCE_TARGET_DURATION_SEC,
    }
    current.updated_at = datetime.utcnow()
    try:
        _dance_status_sync_chat_and_quota(tmp, current, final_url, dash_url)
    except Exception:
        # Best-effort: the task result must still be saved, but make the
        # failure visible instead of silently discarding it.
        logging.getLogger(__name__).exception(
            "Failed to sync chat message/quota for dance task %s", current.id
        )
    tmp.add(current)
    tmp.commit()


def _dance_status_sync_chat_and_quota(
    tmp: Session, current: GenerationTask, final_url, dash_url
) -> None:
    """
    Update the lover's placeholder chat message and deduct one video quota.

    Rows are locked in the order User -> ChatSession -> ChatMessage. The quota
    is decremented only when the payload lacks the ``deducted`` flag and the
    user has quota left; the flag is then set on ``current.payload`` so a
    repeated finalization does not charge twice. Does not commit.
    """
    payload = current.payload or {}
    user_row = (
        tmp.query(User)
        .filter(User.id == current.user_id)
        .with_for_update()
        .first()
    )
    session = None
    if payload.get("session_id"):
        session = (
            tmp.query(ChatSession)
            .filter(
                ChatSession.id == payload.get("session_id"),
                ChatSession.user_id == current.user_id,
            )
            .with_for_update()
            .first()
        )
    lover_msg = None
    if payload.get("lover_message_id"):
        lover_msg = (
            tmp.query(ChatMessage)
            .filter(ChatMessage.id == payload.get("lover_message_id"))
            .with_for_update()
            .first()
        )
    if lover_msg:
        lover_msg.content = f"为你生成了一段跳舞视频,点击查看:{final_url}"
        lover_msg.extra = {
            **(lover_msg.extra or {}),
            "video_url": final_url,
            "dashscope_video_url": dash_url,
            "dashscope_task_id": payload.get("dashscope_task_id"),
            "generation_task_id": current.id,
            "prompt": payload.get("prompt"),
            "model": payload.get("model") or settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
            "resolution": payload.get("resolution") or settings.VIDEO_GEN_RESOLUTION or "480P",
            "duration": payload.get("output_duration") or DANCE_TARGET_DURATION_SEC,
            "base_duration": payload.get("duration") or settings.VIDEO_GEN_DURATION or 5,
            "watermark": payload.get("watermark", False),
            "generation_status": "succeeded",
            "bgm_song_id": payload.get("bgm_song_id"),
            "bgm_audio_url": payload.get("bgm_audio_url"),
            "bgm_audio_url_raw": payload.get("bgm_audio_url_raw"),
            "bgm_start_sec": payload.get("bgm_start_sec"),
            "bgm_duration": payload.get("bgm_duration") or DANCE_TARGET_DURATION_SEC,
        }
        lover_msg.tts_status = lover_msg.tts_status or "pending"
        tmp.add(lover_msg)
    if user_row and not (current.payload or {}).get("deducted"):
        remaining = user_row.video_gen_remaining or 0
        if remaining > 0:
            user_row.video_gen_remaining = remaining - 1
            current.payload = {**(current.payload or {}), "deducted": True}
            tmp.add(user_row)
    if session:
        session.last_message_at = datetime.utcnow()
        tmp.add(session)