1489 lines
53 KiB
Python
1489 lines
53 KiB
Python
import hashlib
|
||
import os
|
||
import random
|
||
import shutil
|
||
import subprocess
|
||
import tempfile
|
||
import time
|
||
from datetime import datetime
|
||
from typing import Optional
|
||
|
||
import oss2
|
||
import requests
|
||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy.exc import IntegrityError
|
||
|
||
from ..config import settings
|
||
from ..db import SessionLocal, get_db
|
||
from ..deps import AuthedUser, get_current_user
|
||
from ..models import (
|
||
ChatMessage,
|
||
ChatSession,
|
||
GenerationTask,
|
||
Lover,
|
||
SongLibrary,
|
||
User,
|
||
)
|
||
from ..response import ApiResponse, success_response
|
||
|
||
try:
|
||
import imageio_ffmpeg # type: ignore
|
||
except Exception: # pragma: no cover
|
||
imageio_ffmpeg = None
|
||
|
||
# Router for the dance-video endpoints; every route below is mounted under "/dance".
router = APIRouter(prefix="/dance", tags=["dance"])
|
||
|
||
|
||
def _ffmpeg_bin() -> str:
    """Resolve the ffmpeg executable to invoke.

    Resolution order: an ``ffmpeg`` found on PATH, then the binary reported
    by the optional ``imageio_ffmpeg`` package, and finally the bare command
    name ``"ffmpeg"``.
    """
    on_path = shutil.which("ffmpeg")
    if on_path:
        return on_path
    if imageio_ffmpeg is None:
        return "ffmpeg"
    try:
        return imageio_ffmpeg.get_ffmpeg_exe()
    except Exception:
        # Best effort: any lookup failure falls back to the bare name.
        return "ffmpeg"
|
||
|
||
|
||
def _ffprobe_bin() -> str:
    """Resolve the ffprobe executable to invoke.

    Resolution order: an ``ffprobe`` found on PATH, then an ``ffprobe.exe``
    or ``ffprobe`` file sitting next to the ffmpeg binary reported by the
    optional ``imageio_ffmpeg`` package, and finally the bare command name
    ``"ffprobe"``.
    """
    on_path = shutil.which("ffprobe")
    if on_path:
        return on_path
    if imageio_ffmpeg is not None:
        try:
            bin_dir = os.path.dirname(imageio_ffmpeg.get_ffmpeg_exe())
            for exe_name in ("ffprobe.exe", "ffprobe"):
                sibling = os.path.join(bin_dir, exe_name)
                if os.path.exists(sibling):
                    return sibling
        except Exception:
            # Best effort: fall through to the bare command name.
            pass
    return "ffprobe"
|
||
|
||
# Target length, in seconds, of the BGM clip (and therefore of the merged dance video).
DANCE_TARGET_DURATION_SEC = 10
|
||
|
||
|
||
@router.get("/history", response_model=ApiResponse[list[dict]])
def get_dance_history(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    """Return the caller's successfully generated dance videos, newest first.

    Only "video" tasks that carry a prompt in their payload, have status
    "succeeded" and a non-null result URL are listed.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.

    Raises:
        HTTPException: 404 when the user has no lover record.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    # Clamp pagination so a bad query string cannot produce a negative
    # OFFSET/LIMIT (which the database rejects with an error).
    page = max(page, 1)
    size = max(size, 0)
    offset = (page - 1) * size

    tasks = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["prompt"].as_string().isnot(None),
            GenerationTask.status == "succeeded",
            GenerationTask.result_url.isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .offset(offset)
        .limit(size)
        .all()
    )

    result: list[dict] = [
        {
            "id": task.id,
            "prompt": (task.payload or {}).get("prompt") or "",
            "video_url": _cdnize(task.result_url) or "",
            "created_at": task.created_at.isoformat() if task.created_at else None,
        }
        for task in tasks
    ]
    return success_response(result, msg="获取成功")
|
||
|
||
|
||
def _download_binary(url: str) -> bytes:
    """Fetch ``url`` with a 30 second timeout and return the response body.

    Raises:
        HTTPException: 502 when the request fails or the status is not 200.
    """
    try:
        response = requests.get(url, timeout=30)
    except Exception as exc:
        raise HTTPException(status_code=502, detail="文件下载失败") from exc
    if response.status_code == 200:
        return response.content
    raise HTTPException(status_code=502, detail="文件下载失败")
|
||
|
||
|
||
def _retry_finalize_dance_task(task_id: int) -> None:
    """Re-finalize a dance task whose DashScope job succeeded but whose local
    download / merge / upload failed (shared by self-healing and manual retry).

    The task row is locked, the DashScope status is re-queried, and when the
    remote job reports SUCCEEDED with a video URL the video is merged with a
    randomly picked BGM, uploaded to OSS, and the task is marked "succeeded".

    The function never raises: it runs as a background task, so failures are
    logged and the task row is left as it was.
    """
    import logging

    try:
        with SessionLocal() as db:
            task = (
                db.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if not task:
                return
            # Already finalized (for example by a concurrent retry that held
            # the row lock before us): do not overwrite a good result.
            if task.status == "succeeded" and task.result_url:
                return
            payload = task.payload or {}
            dash_id = payload.get("dashscope_task_id")
            if not dash_id:
                return

            status, dash_video_url = _fetch_dashscope_status(str(dash_id))
            if status != "SUCCEEDED" or not dash_video_url:
                return

            # Rebuild: download the DashScope video -> random BGM -> merge -> upload.
            # NOTE(review): the row lock is held for the whole download/merge/
            # upload; consider releasing it around the slow work.
            bgm_song = _pick_random_bgm(db)
            bgm_audio_url_raw = bgm_song.audio_url
            bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw
            merged_bytes, bgm_meta = _merge_dance_video_with_bgm(
                dash_video_url,
                bgm_audio_url,
                DANCE_TARGET_DURATION_SEC,
            )
            object_name = f"lover/{task.lover_id}/dance/{int(time.time())}_retry.mp4"
            oss_url = _upload_to_oss(merged_bytes, object_name)

            task.status = "succeeded"
            task.result_url = oss_url
            task.error_msg = None
            task.payload = {
                **payload,
                "dashscope_video_url": dash_video_url,
                "bgm_song_id": bgm_song.id,
                "bgm_audio_url": bgm_audio_url,
                "bgm_audio_url_raw": bgm_audio_url_raw,
                "bgm_start_sec": bgm_meta.get("bgm_start_sec"),
                "bgm_duration": DANCE_TARGET_DURATION_SEC,
                "retry_finalized": True,
            }
            task.updated_at = datetime.utcnow()
            db.add(task)
            db.commit()
    except Exception:
        # Still best effort, but leave a trace instead of failing silently.
        logging.getLogger(__name__).exception(
            "retry finalize failed for dance task %s", task_id
        )
        return
|
||
|
||
|
||
@router.post("/retry/{task_id}", response_model=ApiResponse[dict])
def retry_dance_task(
    task_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Manually retry a dance task.

    Intended for tasks whose DashScope job succeeded while the local
    download / merge / upload failed. The task is flagged with
    ``manual_retry`` and the finalization is scheduled in the background.

    Raises:
        HTTPException: 404 when the lover or the task is not found, 400 when
            the task payload has no ``dashscope_task_id``.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    # The task must belong to this user and lover and be a prompt-based video task.
    ownership_filters = (
        GenerationTask.id == task_id,
        GenerationTask.user_id == user.id,
        GenerationTask.lover_id == lover.id,
        GenerationTask.task_type == "video",
        GenerationTask.payload["prompt"].as_string().isnot(None),
    )
    task = db.query(GenerationTask).filter(*ownership_filters).first()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    current_payload = task.payload or {}
    if not current_payload.get("dashscope_task_id"):
        raise HTTPException(status_code=400, detail="任务缺少 dashscope_task_id,无法重试")

    task.payload = {**current_payload, "manual_retry": True}
    task.updated_at = datetime.utcnow()
    db.add(task)
    db.commit()

    retried_id = int(task.id)
    background_tasks.add_task(_retry_finalize_dance_task, retried_id)
    return success_response({"task_id": retried_id}, msg="已触发重试")
|
||
|
||
|
||
@router.get("/history/all", response_model=ApiResponse[list[dict]])
def get_dance_history_all(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    """Return all of the caller's dance video tasks (succeeded, failed and
    in progress), newest first.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.

    Raises:
        HTTPException: 404 when the user has no lover record.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    # Clamp pagination so a bad query string cannot produce a negative
    # OFFSET/LIMIT (which the database rejects with an error).
    page = max(page, 1)
    size = max(size, 0)
    offset = (page - 1) * size

    tasks = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .offset(offset)
        .limit(size)
        .all()
    )

    result: list[dict] = []
    for task in tasks:
        payload = task.payload or {}
        # Prefer the stored result URL, then fall back to URLs kept in the payload.
        url = task.result_url or payload.get("video_url") or payload.get("dashscope_video_url") or ""
        result.append(
            {
                "id": int(task.id),
                "prompt": payload.get("prompt") or "",
                "status": task.status,
                "video_url": _cdnize(url) or "",
                "error_msg": task.error_msg,
                "created_at": task.created_at.isoformat() if task.created_at else None,
            }
        )

    return success_response(result, msg="获取成功")
|
||
|
||
|
||
class DanceGenerateIn(BaseModel):
    """Request body for starting a dance video generation."""

    # Free-text description of the dance/motion, 2 to 400 characters.
    prompt: str = Field(
        ...,
        min_length=2,
        max_length=400,
        description="用户希望跳的舞/动作描述",
    )
|
||
|
||
|
||
class DanceTaskStatusOut(BaseModel):
    """Status snapshot of a dance generation task returned to the client."""

    # Primary key of the GenerationTask row.
    generation_task_id: int
    status: str = Field(
        ...,
        description="pending|running|succeeded|failed",
    )
    # Empty string / 0 mean "not available yet" for the fields below.
    dashscope_task_id: str = ""
    video_url: str = ""
    session_id: int = 0
    user_message_id: int = 0
    lover_message_id: int = 0
    error_msg: Optional[str] = None
|
||
|
||
|
||
@router.get("/current", response_model=ApiResponse[Optional[DanceTaskStatusOut]])
def get_current_dance_task(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the caller's newest pending or running dance task, or ``None``
    as the payload when no such task exists."""
    in_flight = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
            GenerationTask.status.in_(["pending", "running"]),
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .first()
    )
    if not in_flight:
        return success_response(None, msg="暂无进行中的任务")

    meta = in_flight.payload or {}
    # Prefer the stored result URL, then fall back to URLs kept in the payload.
    video_url = (
        in_flight.result_url
        or meta.get("video_url")
        or meta.get("dashscope_video_url")
        or ""
    )
    status_out = DanceTaskStatusOut(
        generation_task_id=in_flight.id,
        status=in_flight.status,
        dashscope_task_id=str(meta.get("dashscope_task_id") or ""),
        video_url=video_url,
        session_id=int(meta.get("session_id") or 0),
        user_message_id=int(meta.get("user_message_id") or 0),
        lover_message_id=int(meta.get("lover_message_id") or 0),
        error_msg=in_flight.error_msg,
    )
    return success_response(status_out, msg="获取成功")
|
||
|
||
|
||
def _upload_to_oss(file_bytes: bytes, object_name: str) -> str:
    """Upload ``file_bytes`` to the configured OSS bucket as ``object_name``.

    Returns:
        The public URL of the object: CDN based when a CDN domain is
        configured, otherwise ``https://<bucket>.<endpoint-host>/<object>``.

    Raises:
        HTTPException: 500 when OSS credentials or bucket/endpoint settings
            are missing.
    """
    key_id = settings.ALIYUN_OSS_ACCESS_KEY_ID
    key_secret = settings.ALIYUN_OSS_ACCESS_KEY_SECRET
    if not (key_id and key_secret):
        raise HTTPException(status_code=500, detail="未配置 OSS Key")

    bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
    if not (bucket_name and settings.ALIYUN_OSS_ENDPOINT):
        raise HTTPException(status_code=500, detail="未配置 OSS Bucket/Endpoint")

    endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/")
    bucket = oss2.Bucket(oss2.Auth(key_id, key_secret), endpoint, bucket_name)
    bucket.put_object(object_name, file_bytes)

    cdn = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn:
        return f"{cdn.rstrip('/')}/{object_name}"
    # No CDN: build the bucket-hosted URL from the scheme-less endpoint host.
    host = endpoint.replace("https://", "").replace("http://", "")
    return f"https://{bucket_name}.{host}/{object_name}"
|
||
|
||
|
||
def _is_own_oss_url(url: str) -> bool:
    """Tell whether ``url`` points at this service's own OSS storage.

    A URL is considered "own" when it is the configured CDN domain (or a path
    under it), or when it starts with ``https://<bucket>.<endpoint-host>/``.

    The CDN comparison requires a path boundary after the domain so that
    look-alike hosts such as ``https://cdn.example.com.evil.com/...`` are not
    accepted, and an empty/"/" CDN setting no longer matches every URL.
    """
    if not url:
        return False
    cdn = (settings.ALIYUN_OSS_CDN_DOMAIN or "").rstrip("/")
    if cdn and (url == cdn or url.startswith(f"{cdn}/")):
        return True
    bucket = settings.ALIYUN_OSS_BUCKET_NAME
    endpoint = (settings.ALIYUN_OSS_ENDPOINT or "").rstrip("/")
    if bucket and endpoint:
        domain = endpoint.replace("https://", "").replace("http://", "")
        return url.startswith(f"https://{bucket}.{domain}/")
    return False
|
||
|
||
|
||
def _cdnize(url: Optional[str]) -> Optional[str]:
    """Expand a relative object path into a reachable URL.

    Falsy input is returned unchanged and absolute http(s) URLs are returned
    stripped. Otherwise the path is prefixed with, in order of preference:
    the CDN domain, the bucket + endpoint host, or a fixed fallback domain.
    """
    if not url:
        return url
    candidate = url.strip()
    if candidate.startswith(("http://", "https://")):
        return candidate

    relative = candidate.lstrip("/")
    cdn = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn:
        return f"{cdn.rstrip('/')}/{relative}"

    bucket = settings.ALIYUN_OSS_BUCKET_NAME
    endpoint = settings.ALIYUN_OSS_ENDPOINT
    if bucket and endpoint:
        host = endpoint.rstrip("/").replace("https://", "").replace("http://", "")
        return f"https://{bucket}.{host}/{relative}"

    # Last resort: hard-coded bucket domain.
    return f"https://nvlovers.oss-cn-qingdao.aliyuncs.com/{relative}"
|
||
|
||
|
||
def _shorten_url(url: str, max_len: int = 160) -> str:
|
||
if len(url) <= max_len:
|
||
return url
|
||
return f"{url[: max_len - 3]}..."
|
||
|
||
|
||
def _submit_video_task(prompt: str, image_url: str) -> str:
    """Submit an asynchronous image-to-video job to DashScope.

    Args:
        prompt: Motion/dance description sent as the generation prompt.
        image_url: URL of the source image sent as ``img_url``.

    Returns:
        The DashScope task id as a string.

    Raises:
        HTTPException: 500 when ``DASHSCOPE_API_KEY`` is not configured;
            502 when the request fails, returns a non-200 status, returns a
            body that is not a JSON object, or contains no task id.
    """
    if not settings.DASHSCOPE_API_KEY:
        raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")

    resolution = settings.VIDEO_GEN_RESOLUTION or "480P"
    duration = settings.VIDEO_GEN_DURATION or 5
    payload = {
        "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
        "input": {
            "prompt": prompt,
            "img_url": image_url,
        },
        "parameters": {
            "resolution": resolution,
            "duration": duration,
            "watermark": False,
            "prompt_extend": True,
        },
    }

    headers = {
        "X-DashScope-Async": "enable",
        "Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        resp = requests.post(
            "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis",
            headers=headers,
            json=payload,
            timeout=10,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail="调用视频生成接口失败") from exc

    if resp.status_code != 200:
        msg = resp.text
        try:
            msg = resp.json().get("message") or msg
        except Exception:
            # Body is not JSON (or not an object): keep the raw text.
            pass
        raise HTTPException(status_code=502, detail=f"视频任务提交失败: {msg}")

    try:
        data = resp.json()
    except Exception as exc:
        raise HTTPException(status_code=502, detail="视频任务返回解析失败") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail="视频任务返回解析失败")

    # "output" may be present but null; treat that like a missing object
    # instead of crashing with AttributeError on None.
    output = data.get("output") or {}
    task_id = output.get("task_id") or data.get("task_id") or output.get("id")
    if not task_id:
        raise HTTPException(status_code=502, detail="视频任务未返回 task_id")
    return str(task_id)
|
||
|
||
|
||
def _poll_video_url(task_id: str) -> str:
    """Poll DashScope until the video task finishes and return its video URL.

    Polls every 3 seconds for up to 120 seconds. Transport errors, non-200
    responses and unparsable bodies are ignored and polling continues.

    Raises:
        HTTPException: 502 when the task reports FAILED or reports SUCCEEDED
            without a URL; 504 when the deadline passes.
    """
    auth_headers = {"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"}
    status_endpoint = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
    give_up_at = time.time() + 120

    while time.time() < give_up_at:
        time.sleep(3)
        try:
            response = requests.get(status_endpoint, headers=auth_headers, timeout=8)
        except Exception:
            continue
        if response.status_code != 200:
            continue
        try:
            body = response.json()
        except Exception:
            continue

        output = body.get("output") or {}
        raw_state = output.get("task_status") or body.get("task_status") or body.get("status") or ""
        state = str(raw_state).upper()

        if state == "FAILED":
            reason = output.get("message") or body.get("message") or "生成失败"
            raise HTTPException(status_code=502, detail=f"视频生成失败: {reason}")
        if state != "SUCCEEDED":
            # Still pending/running (or an unrecognized state): keep polling.
            continue

        video_url = output.get("video_url") or body.get("video_url")
        if video_url:
            return video_url
        raise HTTPException(status_code=502, detail="视频生成成功但未返回结果 URL")

    raise HTTPException(status_code=504, detail="视频生成超时,请稍后重试")
|
||
|
||
|
||
def _fetch_dashscope_status(task_id: str) -> tuple[str, Optional[str]]:
    """Query the DashScope task status once.

    Returns:
        ``(status, video_url)`` where ``status`` is the upper-cased task
        status (SUCCEEDED / FAILED / other) and ``video_url`` may be ``None``.
        ``("UNKNOWN", None)`` is returned when the request fails, the status
        code is not 200, or the body cannot be parsed as JSON.
    """
    unknown = ("UNKNOWN", None)
    auth_headers = {"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"}
    status_endpoint = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"

    try:
        response = requests.get(status_endpoint, headers=auth_headers, timeout=8)
    except Exception:
        return unknown
    if response.status_code != 200:
        return unknown
    try:
        body = response.json()
    except Exception:
        return unknown

    output = body.get("output") or {}
    raw_state = output.get("task_status") or body.get("task_status") or body.get("status") or ""
    found_url = output.get("video_url") or body.get("video_url")
    return str(raw_state).upper(), found_url
|
||
|
||
|
||
def _download_to_path(url: str, target_path: str, label: str):
    """Stream ``url`` into the file at ``target_path``.

    Args:
        url: Source URL; must be non-empty.
        target_path: Local file path that is (over)written.
        label: Human-readable name of the resource, used in error messages.

    Raises:
        HTTPException: 502 when the URL is empty, the request fails, the
            status is not 200, or the transfer breaks mid-stream.
    """
    if not url:
        raise HTTPException(status_code=502, detail=f"{label}下载失败: 空URL")
    try:
        resp = requests.get(url, stream=True, timeout=30)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"{label}下载失败: {_shorten_url(url)}") from exc

    # The response is streamed, so it must be closed on every path (including
    # the non-200 one) to release the underlying connection.
    try:
        if resp.status_code != 200:
            raise HTTPException(
                status_code=502,
                detail=f"{label}下载失败({resp.status_code}): {_shorten_url(url)}",
            )
        with open(target_path, "wb") as file_handle:
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    file_handle.write(chunk)
    except requests.RequestException as exc:
        # Connection dropped / timed out while reading the body.
        raise HTTPException(status_code=502, detail=f"{label}下载失败: {_shorten_url(url)}") from exc
    finally:
        resp.close()
|
||
|
||
|
||
def _probe_media_duration(path: str) -> Optional[float]:
    """Return the duration of the media file at ``path`` in seconds.

    Runs ffprobe asking only for ``format=duration``. Returns ``None`` when
    ffprobe is missing, exits with an error, prints nothing, prints a
    non-numeric value, or reports a duration that is not positive.
    """
    probe_cmd = [
        _ffprobe_bin(),
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    try:
        completed = subprocess.run(
            probe_cmd,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None

    text = completed.stdout.decode("utf-8", errors="ignore").strip()
    if not text:
        return None
    try:
        seconds = float(text)
    except ValueError:
        return None
    return None if seconds <= 0 else seconds
|
||
|
||
|
||
def _run_ffmpeg_merge(video_path: str, audio_path: str, output_path: str):
    """Mux ``video_path`` (looped) with ``audio_path`` into ``output_path``.

    The video input is looped indefinitely; the output is limited to the
    probed audio duration when that is known, and ``-shortest`` ends it with
    the audio stream. Video is re-encoded with libx264 and audio with AAC.

    Raises:
        HTTPException: 500 when ffmpeg cannot be executed, 502 when ffmpeg
            exits with an error (first 200 characters of stderr included).
    """
    audio_seconds = _probe_media_duration(audio_path)

    input_args = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel", "error",
        "-stream_loop", "-1",
        "-i", video_path,
        "-i", audio_path,
    ]
    # Cap the output length only when the audio duration could be probed.
    limit_args = ["-t", f"{audio_seconds:.3f}"] if audio_seconds else []
    encode_args = [
        "-map", "0:v:0",
        "-map", "1:a:0",
        "-c:v", "libx264",
        "-preset", "veryfast",
        "-crf", "23",
        "-pix_fmt", "yuv420p",
        "-c:a", "aac",
        "-shortest",
        "-movflags", "+faststart",
        output_path,
    ]
    command = [*input_args, *limit_args, *encode_args]

    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        err_text = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"ffmpeg 合成失败: {err_text[:200]}") from exc
|
||
|
||
|
||
def _extract_audio_segment(
    input_path: str,
    start_sec: float,
    duration_sec: float,
    output_path: str,
):
    """Cut ``duration_sec`` seconds of audio starting at ``start_sec``.

    The segment is written to ``output_path`` as MP3 (libmp3lame, 128 kbit/s,
    44.1 kHz, stereo); any video stream is dropped.

    Raises:
        HTTPException: 500 when ffmpeg cannot be executed, 502 when ffmpeg
            exits with an error (first 200 characters of stderr included).
    """
    command = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel", "error",
        "-i", input_path,
        "-ss", f"{start_sec:.3f}",
        "-t", f"{duration_sec:.3f}",
        "-vn",
        "-acodec", "libmp3lame",
        "-b:a", "128k",
        "-ar", "44100",
        "-ac", "2",
        output_path,
    ]
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        err_text = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"音频分段失败: {err_text[:200]}") from exc
|
||
|
||
|
||
def _pad_audio_segment(
    input_path: str,
    pad_sec: float,
    target_duration_sec: float,
    output_path: str,
):
    """Append ``pad_sec`` seconds of silence and trim to ``target_duration_sec``.

    Does nothing (and writes no output file) when ``pad_sec`` is not
    positive. Output is MP3 (libmp3lame, 128 kbit/s, 44.1 kHz, stereo).

    Raises:
        HTTPException: 500 when ffmpeg cannot be executed, 502 when ffmpeg
            exits with an error (first 200 characters of stderr included).
    """
    if pad_sec <= 0:
        return

    command = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel", "error",
        "-i", input_path,
        "-af", f"apad=pad_dur={pad_sec:.3f}",
        "-t", f"{target_duration_sec:.3f}",
        "-acodec", "libmp3lame",
        "-b:a", "128k",
        "-ar", "44100",
        "-ac", "2",
        output_path,
    ]
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        err_text = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"音频补齐失败: {err_text[:200]}") from exc
|
||
|
||
|
||
def _pick_random_bgm(db: Session) -> SongLibrary:
    """Pick one random usable song from the song library.

    A song is usable when its status is true, it is not soft-deleted
    (``deletetime`` is null) and it has a non-empty ``audio_url``.

    Raises:
        HTTPException: 404 when no usable song exists.
    """
    usable_songs = db.query(SongLibrary).filter(
        SongLibrary.status.is_(True),
        SongLibrary.deletetime.is_(None),
        SongLibrary.audio_url.isnot(None),
        SongLibrary.audio_url != "",
    )
    song_count = usable_songs.count()
    if song_count <= 0:
        raise HTTPException(status_code=404, detail="音频库暂无可用音频")

    # Random row via a random offset into the filtered set.
    picked = usable_songs.offset(random.randrange(song_count)).limit(1).first()
    if picked and picked.audio_url:
        return picked
    raise HTTPException(status_code=404, detail="音频库暂无可用音频")
|
||
|
||
|
||
def _merge_dance_video_with_bgm(
    base_video_url: str,
    audio_url: str,
    target_duration_sec: int,
) -> tuple[bytes, dict]:
    """Download a video and a BGM track and merge them into one MP4.

    A clip of ``target_duration_sec`` seconds is cut from the BGM at a random
    start position (or, when the track is shorter, the whole track is used
    and padded with silence), then muxed onto the video.

    Returns:
        ``(merged_bytes, meta)`` where ``meta`` holds ``bgm_start_sec``
        (rounded to 3 decimals) and ``bgm_duration``.

    Raises:
        HTTPException: 502 when the audio duration cannot be probed; errors
            raised by the download / ffmpeg helpers propagate.
    """
    with tempfile.TemporaryDirectory() as workdir:
        video_file = os.path.join(workdir, "base.mp4")
        bgm_file = os.path.join(workdir, "audio_source.mp3")

        _download_to_path(base_video_url, video_file, "跳舞视频")
        _download_to_path(_cdnize(audio_url) or audio_url, bgm_file, "BGM音频")

        bgm_seconds = _probe_media_duration(bgm_file)
        if not bgm_seconds:
            raise HTTPException(status_code=502, detail="音频时长解析失败")

        if bgm_seconds > target_duration_sec:
            # Track is longer than needed: cut a full-length window at a random offset.
            latest_start = bgm_seconds - target_duration_sec
            start_sec = random.uniform(0.0, latest_start)
            clip_seconds = float(target_duration_sec)
        else:
            start_sec = 0.0
            clip_seconds = min(bgm_seconds, float(target_duration_sec))

        audio_file = os.path.join(workdir, "audio_clip.mp3")
        _extract_audio_segment(bgm_file, start_sec, clip_seconds, audio_file)

        if clip_seconds < target_duration_sec:
            # Track is shorter than the target: pad the tail with silence.
            padded_file = os.path.join(workdir, "audio_clip_pad.mp3")
            _pad_audio_segment(
                audio_file,
                target_duration_sec - clip_seconds,
                target_duration_sec,
                padded_file,
            )
            audio_file = padded_file

        merged_file = os.path.join(workdir, "merged.mp4")
        _run_ffmpeg_merge(video_file, audio_file, merged_file)

        # Read the result before the temporary directory is removed.
        with open(merged_file, "rb") as merged_handle:
            merged_bytes = merged_handle.read()

        meta = {
            "bgm_start_sec": round(start_sec, 3),
            "bgm_duration": target_duration_sec,
        }
        return merged_bytes, meta
|
||
|
||
|
||
def _get_or_create_session(db: Session, user: AuthedUser, lover: Lover, session_id: Optional[int]) -> ChatSession:
    """Return the chat session to use for this user and lover.

    With ``session_id``: the user's session with that id, or a 404.
    Without it: the newest active session for the user/lover pair (older
    active sessions are set to "archived"), or a newly created and flushed
    session when none is active.

    Raises:
        HTTPException: 404 when ``session_id`` is given but not found.
    """
    if session_id:
        requested = (
            db.query(ChatSession)
            .filter(ChatSession.id == session_id, ChatSession.user_id == user.id)
            .first()
        )
        if requested:
            return requested
        raise HTTPException(status_code=404, detail="会话不存在")

    active_sessions = (
        db.query(ChatSession)
        .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id, ChatSession.status == "active")
        .with_for_update()
        .order_by(ChatSession.created_at.desc())
        .all()
    )
    if active_sessions:
        newest, *older = active_sessions
        # Keep only the newest session active.
        for stale in older:
            if stale.status == "active":
                stale.status = "archived"
                db.add(stale)
        return newest

    created_at = datetime.utcnow()
    fresh = ChatSession(
        user_id=user.id,
        lover_id=lover.id,
        model=settings.LLM_MODEL or "qwen-flash",
        status="active",
        last_message_at=created_at,
        created_at=created_at,
        updated_at=created_at,
        inner_voice_enabled=False,
    )
    db.add(fresh)
    # Flush so the new session gets its primary key before it is returned.
    db.flush()
    return fresh
|
||
|
||
|
||
def _next_seq(db: Session, session_id: int) -> int:
    """Return the next message sequence number for ``session_id``.

    That is the highest existing ``seq`` plus one, or 1 when the session has
    no messages (or the highest ``seq`` is empty/zero).
    """
    latest = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    if latest and latest.seq:
        return latest.seq + 1
    return 1
|
||
|
||
|
||
def _mark_task_failed(db: Session, task_id: int, msg: str):
    """Mark a generation task as failed and commit.

    When the task payload references a lover message, that message's content
    and ``extra`` are rewritten to report the failure; when it references a
    session, the session's ``last_message_at`` is refreshed. The task gets
    status "failed" and ``error_msg`` set to the first 255 characters of
    ``msg``. Nothing happens when the task does not exist.
    """
    task = (
        db.query(GenerationTask)
        .filter(GenerationTask.id == task_id)
        .with_for_update()
        .first()
    )
    if not task:
        return

    task_payload = task.payload or {}
    placeholder_id = task_payload.get("lover_message_id")
    chat_session_id = task_payload.get("session_id")

    if placeholder_id:
        placeholder = (
            db.query(ChatMessage)
            .filter(ChatMessage.id == placeholder_id)
            .with_for_update()
            .first()
        )
        if placeholder:
            # Surface the failure in the chat bubble that announced the video.
            failure_extra = dict(placeholder.extra or {})
            failure_extra["generation_status"] = "failed"
            failure_extra["error_msg"] = msg
            failure_extra["generation_task_id"] = task_id
            placeholder.content = f"跳舞视频生成失败:{msg}"
            placeholder.extra = failure_extra
            placeholder.tts_status = placeholder.tts_status or "pending"
            db.add(placeholder)

    if chat_session_id:
        chat_session = (
            db.query(ChatSession)
            .filter(ChatSession.id == chat_session_id)
            .with_for_update()
            .first()
        )
        if chat_session:
            chat_session.last_message_at = datetime.utcnow()
            db.add(chat_session)

    task.status = "failed"
    task.error_msg = msg[:255]
    task.updated_at = datetime.utcnow()
    db.add(task)
    db.commit()
|
||
|
||
|
||
def _process_dance_task(task_id: int):
    """Run a dance video generation task end to end (background worker).

    Phase 1 loads and validates the task and marks it "running" in a short
    transaction. Phase 2 submits/polls DashScope, picks a BGM, merges and
    uploads the video, holding no long-lived DB session. Phase 3 writes the
    result back: chat messages, quota deduction and task status.

    Any failure marks the task as failed via ``_mark_task_failed``; failures
    while marking are swallowed because there is no caller to report to.
    Tasks already "succeeded" or "failed" are left untouched.
    """
    prompt = ""
    image_url = ""
    session_id = None
    user_message_id = None
    lover_message_id = None
    dash_task_id = None
    lover_id = None
    bgm_song_id = None
    bgm_audio_url_raw = None
    bgm_audio_url = None
    bgm_start_sec = None

    # Phase 1: read the task, validate, mark it running, release the connection.
    db = None
    try:
        db = SessionLocal()
        task = (
            db.query(GenerationTask)
            .filter(GenerationTask.id == task_id)
            .with_for_update()
            .first()
        )
        if not task or task.status in ("succeeded", "failed"):
            db.rollback()
            return

        user_id = task.user_id
        lover_id = task.lover_id
        payload = task.payload or {}
        prompt = payload.get("prompt") or ""
        image_url = payload.get("image_url")
        session_id = payload.get("session_id")
        user_message_id = payload.get("user_message_id")
        lover_message_id = payload.get("lover_message_id")
        dash_task_id = payload.get("dashscope_task_id")

        # Basic existence checks so obviously doomed tasks fail fast.
        lover = db.query(Lover).filter(Lover.id == lover_id).first()
        user_row = db.query(User).filter(User.id == user_id).first()
        if not lover:
            raise HTTPException(status_code=404, detail="恋人不存在,请重新创建")
        if not lover.image_url:
            raise HTTPException(status_code=400, detail="请先生成并确认恋人形象")
        if not user_row:
            raise HTTPException(status_code=404, detail="用户不存在")
        if (user_row.video_gen_remaining or 0) <= 0:
            raise HTTPException(status_code=400, detail="视频生成次数不足")
        image_url = image_url or lover.image_url

        task.status = "running"
        task.updated_at = datetime.utcnow()
        db.add(task)
        db.commit()
    except Exception as exc:
        detail = str(exc.detail) if isinstance(exc, HTTPException) else str(exc)[:255]
        if db is not None:
            try:
                # Reset the session first: after a DB error it cannot run the
                # queries needed to mark the task failed.
                db.rollback()
                _mark_task_failed(db, task_id, detail)
            except Exception:
                pass
        return
    finally:
        if db is not None:
            try:
                db.close()
            except Exception:
                pass

    # Phase 2: submit / poll DashScope and build the video (short-lived sessions only).
    try:
        if not dash_task_id:
            dash_task_id = _submit_video_task(prompt, image_url)
            # Persist the remote task id right away so a retry can resume from it.
            with SessionLocal() as db_tmp:
                task_row = (
                    db_tmp.query(GenerationTask)
                    .filter(GenerationTask.id == task_id)
                    .with_for_update()
                    .first()
                )
                if task_row:
                    task_row.payload = {**(task_row.payload or {}), "dashscope_task_id": dash_task_id}
                    task_row.status = "running"
                    task_row.updated_at = datetime.utcnow()
                    db_tmp.add(task_row)
                    db_tmp.commit()

        dash_video_url = _poll_video_url(dash_task_id)
        with SessionLocal() as db_tmp:
            bgm_song = _pick_random_bgm(db_tmp)
            bgm_song_id = bgm_song.id
            bgm_audio_url_raw = bgm_song.audio_url
            bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw
            task_row = (
                db_tmp.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if task_row:
                task_row.payload = {
                    **(task_row.payload or {}),
                    "bgm_song_id": bgm_song_id,
                    "bgm_audio_url": bgm_audio_url,
                    "bgm_audio_url_raw": bgm_audio_url_raw,
                }
                task_row.updated_at = datetime.utcnow()
                db_tmp.add(task_row)
                db_tmp.commit()
        merged_bytes, bgm_meta = _merge_dance_video_with_bgm(
            dash_video_url,
            bgm_audio_url,
            DANCE_TARGET_DURATION_SEC,
        )
        bgm_start_sec = bgm_meta.get("bgm_start_sec")
        safe_prompt_tag = "prompt"
        object_name = f"lover/{lover_id}/dance/{int(time.time())}_{safe_prompt_tag}.mp4"
        oss_url = _upload_to_oss(merged_bytes, object_name)
    except Exception as exc:
        detail = str(exc) if not hasattr(exc, "detail") else str(exc.detail)
        try:
            # Context manager guarantees the failure-reporting session is closed.
            with SessionLocal() as db_fail:
                _mark_task_failed(db_fail, task_id, detail)
        except Exception:
            pass
        return

    # Phase 3: write back the result and chat messages (short transaction).
    db = SessionLocal()
    try:
        task_row = (
            db.query(GenerationTask)
            .filter(GenerationTask.id == task_id)
            .with_for_update()
            .first()
        )
        if not task_row:
            db.rollback()
            return

        lover = db.query(Lover).filter(Lover.id == task_row.lover_id).first()
        user_row = (
            db.query(User)
            .filter(User.id == task_row.user_id)
            .with_for_update()
            .first()
        )
        if not lover or not user_row:
            raise HTTPException(status_code=404, detail="用户或恋人不存在")

        # Resolve the chat session, creating one when needed.
        session = None
        if session_id:
            session = (
                db.query(ChatSession)
                .filter(ChatSession.id == session_id, ChatSession.user_id == user_row.id)
                .with_for_update()
                .first()
            )
        if not session:
            session = _get_or_create_session(
                db,
                AuthedUser(
                    id=user_row.id,
                    reg_step=user_row.reg_step or 0,
                    gender=user_row.gender or 0,
                    nickname=user_row.nickname or user_row.username or "",
                    token=user_row.token or "",
                ),
                lover,
                None,
            )

        # User message: reuse the referenced one or create it from the prompt.
        user_msg = None
        if user_message_id:
            user_msg = (
                db.query(ChatMessage)
                .filter(ChatMessage.id == user_message_id, ChatMessage.session_id == session.id)
                .with_for_update()
                .first()
            )
        if not user_msg:
            now = datetime.utcnow()
            next_seq = _next_seq(db, session.id)
            user_msg = ChatMessage(
                session_id=session.id,
                user_id=user_row.id,
                lover_id=lover.id,
                role="user",
                content_type="text",
                content=prompt,
                seq=next_seq,
                created_at=now,
                model=settings.LLM_MODEL or "qwen-flash",
            )
            db.add(user_msg)
            db.flush()
        next_seq = user_msg.seq or _next_seq(db, session.id)

        # Lover message: reuse the placeholder or create one right after the user message.
        lover_msg = None
        if lover_message_id:
            lover_msg = (
                db.query(ChatMessage)
                .filter(ChatMessage.id == lover_message_id, ChatMessage.session_id == session.id)
                .with_for_update()
                .first()
            )
        if not lover_msg:
            lover_msg = ChatMessage(
                session_id=session.id,
                user_id=user_row.id,
                lover_id=lover.id,
                role="lover",
                content_type="text",
                content="正在为你生成跳舞视频,完成后会自动更新此消息",
                seq=next_seq + 1,
                created_at=datetime.utcnow(),
                model=settings.LLM_MODEL or "qwen-flash",
                extra={
                    "generation_task_id": task_row.id,
                    "generation_status": "pending",
                    "prompt": prompt,
                },
                tts_status="pending",
            )
            db.add(lover_msg)
            db.flush()

        lover_msg.content = f"为你生成了一段跳舞视频,点击查看:{oss_url}"
        lover_msg.extra = {
            **(lover_msg.extra or {}),
            "video_url": oss_url,
            "dashscope_video_url": dash_video_url,
            "dashscope_task_id": dash_task_id,
            "generation_task_id": task_row.id,
            "prompt": prompt,
            "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
            "resolution": settings.VIDEO_GEN_RESOLUTION or "480P",
            "duration": DANCE_TARGET_DURATION_SEC,
            "base_duration": settings.VIDEO_GEN_DURATION or 5,
            "watermark": False,
            "generation_status": "succeeded",
            "bgm_song_id": bgm_song_id,
            "bgm_audio_url": bgm_audio_url,
            "bgm_audio_url_raw": bgm_audio_url_raw,
            "bgm_start_sec": bgm_start_sec,
            "bgm_duration": DANCE_TARGET_DURATION_SEC,
        }
        lover_msg.tts_status = lover_msg.tts_status or "pending"
        db.add(lover_msg)

        # Deduct quota (never below zero, at most once per task) and refresh
        # the session timestamp.
        already_deducted = (task_row.payload or {}).get("deducted")
        remaining = user_row.video_gen_remaining or 0
        if remaining > 0 and not already_deducted:
            user_row.video_gen_remaining = remaining - 1
        session.last_message_at = datetime.utcnow()
        db.add(user_row)
        db.add(session)

        task_row.status = "succeeded"
        task_row.result_url = oss_url
        task_row.payload = {
            **(task_row.payload or {}),
            "dashscope_video_url": dash_video_url,
            "dashscope_task_id": dash_task_id,
            "session_id": session.id,
            "user_message_id": user_msg.id,
            "lover_message_id": lover_msg.id,
            "deducted": True,
            "output_duration": DANCE_TARGET_DURATION_SEC,
            "bgm_song_id": bgm_song_id,
            "bgm_audio_url": bgm_audio_url,
            "bgm_audio_url_raw": bgm_audio_url_raw,
            "bgm_start_sec": bgm_start_sec,
            "bgm_duration": DANCE_TARGET_DURATION_SEC,
        }
        task_row.updated_at = datetime.utcnow()
        db.add(task_row)
        db.commit()
    except Exception as exc:
        detail = str(exc.detail) if isinstance(exc, HTTPException) else str(exc)[:255]
        try:
            # Discard the partial write-back before recording the failure.
            db.rollback()
            _mark_task_failed(db, task_id, detail)
        except Exception:
            pass
    finally:
        db.close()
|
||
|
||
|
||
def _lock_user_and_check_video_quota(db: Session, user_id) -> User:
    """Lock the user's row and verify that a new video task may be started.

    The user row is selected ``FOR UPDATE`` so the quota / in-flight checks
    below are evaluated under a row lock held by the current transaction.

    Args:
        db: Session whose current transaction will hold the row lock.
        user_id: Primary key of the requesting user.

    Returns:
        The locked ``User`` row.

    Raises:
        HTTPException: 404 if the user row does not exist, 400 if no video
            generations remain, 409 if the user already has a video task in
            ``pending`` or ``running`` state.
    """
    user_row = (
        db.query(User)
        .filter(User.id == user_id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    if (user_row.video_gen_remaining or 0) <= 0:
        raise HTTPException(status_code=400, detail="视频生成次数不足")

    pending_task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user_id,
            GenerationTask.task_type == "video",
            GenerationTask.status.in_(["pending", "running"]),
        )
        .first()
    )
    if pending_task:
        raise HTTPException(status_code=409, detail="已有视频生成任务进行中,请稍后再试")
    return user_row


@router.post("/generate", response_model=ApiResponse[DanceTaskStatusOut])
def generate_dance_video(
    payload: DanceGenerateIn,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Submit a dance-video generation task for the current user's lover.

    Creates a ``pending`` ``GenerationTask``, reserves two chat messages (the
    user's prompt and a placeholder reply from the lover), commits, and then
    schedules ``_process_dance_task`` as a background task.

    Raises:
        HTTPException: 400 for an empty prompt, a lover without an image, or
            an exhausted quota; 404 if the lover or user is missing; 409 if a
            video task is already in progress or an identical one was just
            submitted.
    """
    prompt = (payload.prompt or "").strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="请输入想要跳的舞/唱的歌/动作描述")

    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人不存在,请先完成创建流程")
    if not lover.image_url:
        raise HTTPException(status_code=400, detail="请先生成并确认恋人形象")

    _lock_user_and_check_video_quota(db, user.id)

    idem_key_src = f"video:{user.id}:{lover.image_url}:{prompt}"
    idem_key = hashlib.sha256(idem_key_src.encode("utf-8")).hexdigest()

    # One timestamp for both columns so a freshly created row has
    # created_at == updated_at.
    task_created_at = datetime.utcnow()
    task = GenerationTask(
        user_id=user.id,
        lover_id=lover.id,
        task_type="video",
        status="pending",
        idempotency_key=idem_key,
        payload={
            "image_url": lover.image_url,
            "prompt": prompt,
            "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
            "resolution": settings.VIDEO_GEN_RESOLUTION or "480P",
            "duration": settings.VIDEO_GEN_DURATION or 5,
            "watermark": False,
            "output_duration": DANCE_TARGET_DURATION_SEC,
        },
        created_at=task_created_at,
        updated_at=task_created_at,
    )
    db.add(task)
    try:
        db.flush()
    except IntegrityError:
        db.rollback()
        existing = (
            db.query(GenerationTask)
            .filter(GenerationTask.idempotency_key == idem_key)
            .first()
        )
        if existing and existing.status in ("pending", "running"):
            raise HTTPException(status_code=409, detail="已提交相同的视频生成任务,请稍后查看结果")
        # The rollback above ended the transaction and with it the FOR UPDATE
        # lock on the user row. Re-acquire the lock and re-validate before
        # retrying, otherwise a concurrent request could slip in between and
        # leave the user with two in-flight video tasks.
        _lock_user_and_check_video_quota(db, user.id)
        # The earlier task with this key has finished; retry under a fresh
        # key so the idempotency constraint is not hit again.
        retry_key = hashlib.sha256(f"{idem_key}:{time.time()}".encode()).hexdigest()
        task.idempotency_key = retry_key
        db.add(task)
        db.flush()

    # Reserve the chat session and message sequence numbers now so that
    # messages sent while the video is generating cannot jump ahead of it.
    session = _get_or_create_session(db, user, lover, None)
    now = datetime.utcnow()
    next_seq = _next_seq(db, session.id)
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="text",
        content=prompt,
        seq=next_seq,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
    )
    db.add(user_msg)
    db.flush()

    # Placeholder reply; its extra["generation_status"] starts as "pending".
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content="正在为你生成跳舞视频,完成后会自动更新此消息",
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        extra={
            "generation_task_id": task.id,
            "generation_status": "pending",
            "prompt": prompt,
        },
        tts_status="pending",
    )
    db.add(lover_msg)
    db.flush()

    session.last_message_at = datetime.utcnow()
    db.add(session)

    # Record the reserved chat rows on the task so the worker can find them.
    task.payload = {
        **(task.payload or {}),
        "session_id": session.id,
        "user_message_id": user_msg.id,
        "lover_message_id": lover_msg.id,
    }
    db.add(task)

    db.commit()
    # Schedule only after the commit so the worker can see the task row.
    background_tasks.add_task(_process_dance_task, task.id)

    return success_response(
        DanceTaskStatusOut(
            generation_task_id=task.id,
            status="pending",
            dashscope_task_id="",
            video_url="",
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            error_msg=None,
        ),
        msg="视频生成任务已提交,正在生成",
    )
|
||
|
||
|
||
def _sync_dance_chat_and_quota(
    tmp: Session,
    current: GenerationTask,
    final_url: Optional[str],
    dash_url: Optional[str],
) -> None:
    """Mirror a succeeded task into the chat placeholder, quota and session.

    Row locks are taken in the order user -> chat session -> lover message.
    The quota is decremented at most once per task, guarded by the
    ``deducted`` flag in the task payload, and never below zero.

    Args:
        tmp: Session that already holds the lock on ``current``.
        current: The task being finalized; its payload must already be set.
        final_url: URL stored as the task's result.
        dash_url: Video URL reported by DashScope, if any.
    """
    user_row = (
        tmp.query(User)
        .filter(User.id == current.user_id)
        .with_for_update()
        .first()
    )

    payload = current.payload or {}

    session = None
    session_id = payload.get("session_id")
    if session_id:
        session = (
            tmp.query(ChatSession)
            .filter(
                ChatSession.id == session_id,
                ChatSession.user_id == current.user_id,
            )
            .with_for_update()
            .first()
        )

    lover_msg = None
    lover_message_id = payload.get("lover_message_id")
    if lover_message_id:
        lover_msg = (
            tmp.query(ChatMessage)
            .filter(ChatMessage.id == lover_message_id)
            .with_for_update()
            .first()
        )

    if lover_msg:
        lover_msg.content = f"为你生成了一段跳舞视频,点击查看:{final_url}"
        lover_msg.extra = {
            **(lover_msg.extra or {}),
            "video_url": final_url,
            "dashscope_video_url": dash_url,
            "dashscope_task_id": payload.get("dashscope_task_id"),
            "generation_task_id": current.id,
            "prompt": payload.get("prompt"),
            "model": payload.get("model") or settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash",
            "resolution": payload.get("resolution") or settings.VIDEO_GEN_RESOLUTION or "480P",
            "duration": payload.get("output_duration") or DANCE_TARGET_DURATION_SEC,
            "base_duration": payload.get("duration") or settings.VIDEO_GEN_DURATION or 5,
            "watermark": payload.get("watermark", False),
            "generation_status": "succeeded",
            "bgm_song_id": payload.get("bgm_song_id"),
            "bgm_audio_url": payload.get("bgm_audio_url"),
            "bgm_audio_url_raw": payload.get("bgm_audio_url_raw"),
            "bgm_start_sec": payload.get("bgm_start_sec"),
            "bgm_duration": payload.get("bgm_duration") or DANCE_TARGET_DURATION_SEC,
        }
        lover_msg.tts_status = lover_msg.tts_status or "pending"
        tmp.add(lover_msg)

    if user_row and not payload.get("deducted"):
        remaining = user_row.video_gen_remaining or 0
        if remaining > 0:
            user_row.video_gen_remaining = remaining - 1
            current.payload = {**payload, "deducted": True}
            tmp.add(user_row)

    if session:
        session.last_message_at = datetime.utcnow()
        tmp.add(session)


def _backfill_succeeded_dance_task(
    tmp: Session,
    current: GenerationTask,
    dash_video: Optional[str],
) -> None:
    """Finalize a task that DashScope reports as succeeded, then commit.

    If the result is not yet hosted on our own OSS, the DashScope video is
    merged with a random background track and uploaded first. A failure in
    that step marks the task ``failed`` and commits. Otherwise the task is
    marked ``succeeded`` and the chat placeholder / quota are synced on a
    best-effort basis.

    Args:
        tmp: Session holding the ``FOR UPDATE`` lock on ``current``.
        current: The locked task row.
        dash_video: Second value returned by ``_fetch_dashscope_status``
            (presumably the generated video URL; may be empty).
    """
    import logging  # local import: the module does not import logging at top level

    task_pk = current.id  # captured early so the log call below never touches the ORM
    stored = current.payload or {}
    dash_url = dash_video or stored.get("dashscope_video_url") or current.result_url
    final_url = current.result_url or dash_url
    bgm_song_id = stored.get("bgm_song_id")
    bgm_audio_url = stored.get("bgm_audio_url")
    bgm_start_sec = stored.get("bgm_start_sec")

    if dash_url and not _is_own_oss_url(final_url or ""):
        try:
            bgm_song = _pick_random_bgm(tmp)
            bgm_song_id = bgm_song.id
            bgm_audio_url_raw = bgm_song.audio_url
            bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw
            merged_bytes, bgm_meta = _merge_dance_video_with_bgm(
                dash_url,
                bgm_audio_url,
                DANCE_TARGET_DURATION_SEC,
            )
            bgm_start_sec = bgm_meta.get("bgm_start_sec")
            object_name = f"lover/{current.lover_id}/dance/{int(time.time())}_prompt.mp4"
            final_url = _upload_to_oss(merged_bytes, object_name)
            current.payload = {
                **(current.payload or {}),
                "bgm_song_id": bgm_song_id,
                "bgm_audio_url": bgm_audio_url,
                "bgm_audio_url_raw": bgm_audio_url_raw,
            }
        except Exception as exc:
            reason = str(exc.detail) if hasattr(exc, "detail") else str(exc)
            current.status = "failed"
            # Capped like the generic failure path in _process_dance_task
            # (presumably the error_msg column width -- TODO confirm).
            current.error_msg = reason[:255]
            current.updated_at = datetime.utcnow()
            tmp.add(current)
            tmp.commit()
            return

    current.status = "succeeded"
    current.result_url = final_url
    current.payload = {
        **(current.payload or {}),
        "dashscope_video_url": dash_url or (current.payload or {}).get("dashscope_video_url"),
        "output_duration": DANCE_TARGET_DURATION_SEC,
        "bgm_song_id": bgm_song_id,
        "bgm_audio_url": bgm_audio_url,
        "bgm_start_sec": bgm_start_sec,
        "bgm_duration": DANCE_TARGET_DURATION_SEC,
    }
    current.updated_at = datetime.utcnow()

    # Best effort: a failure here must not prevent the task itself from being
    # recorded as succeeded, but it should leave a trace in the logs.
    try:
        _sync_dance_chat_and_quota(tmp, current, final_url, dash_url)
    except Exception:
        logging.getLogger(__name__).exception(
            "dance task %s: failed to sync chat placeholder / quota", task_pk
        )

    tmp.add(current)
    tmp.commit()


@router.get("/generate/{task_id}", response_model=ApiResponse[DanceTaskStatusOut])
def get_dance_task(
    task_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the status of one of the current user's video generation tasks.

    While the task is ``pending`` / ``running`` this also re-schedules
    ``_process_dance_task`` and, when a DashScope task id is already stored,
    queries DashScope directly and back-fills the terminal state so a client
    polling this endpoint does not wait forever.

    Raises:
        HTTPException: 404 if no video task with this id belongs to the user.
    """

    def _owned_video_task(sess: Session):
        # Same ownership filter for every read: id + owner + task type.
        return sess.query(GenerationTask).filter(
            GenerationTask.id == task_id,
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
        )

    task = _owned_video_task(db).first()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    status_msg_map = {
        "pending": "视频生成中",
        "running": "视频生成中",
        "succeeded": "视频生成成功",
        "failed": "视频生成失败",
    }
    resp_msg = status_msg_map.get(task.status or "", "查询成功")

    # Still pending/running: trigger compensating processing
    # (asynchronous worker plus a synchronous DashScope check).
    if task.status in ("pending", "running"):
        # Prefer the asynchronous path so the request is not blocked.
        background_tasks.add_task(_process_dance_task, task.id)

        # Re-read through a fresh session so the request session's
        # transaction snapshot does not hide a newer status.
        with SessionLocal() as tmp:
            latest = _owned_video_task(tmp).first()
            if latest:
                task = latest

        # Still stuck: ask DashScope directly and back-fill the result or
        # the failure.
        if task.status in ("pending", "running"):
            dash_task_id = (task.payload or {}).get("dashscope_task_id")
            if dash_task_id:
                status, dash_video = _fetch_dashscope_status(dash_task_id)
                with SessionLocal() as tmp:
                    current = _owned_video_task(tmp).with_for_update().first()
                    if current:
                        if status == "SUCCEEDED":
                            _backfill_succeeded_dance_task(tmp, current, dash_video)
                        elif status == "FAILED":
                            current.status = "failed"
                            current.error_msg = "视频生成失败(DashScope 返回失败)"
                            current.updated_at = datetime.utcnow()
                            tmp.add(current)
                            tmp.commit()

    # Read the final state through a fresh session to avoid
    # DetachedInstanceError on objects loaded by the sessions closed above.
    with SessionLocal() as snap:
        latest = _owned_video_task(snap).first()
        if not latest:
            raise HTTPException(status_code=404, detail="任务不存在")
        payload = latest.payload or {}
        resp_msg = status_msg_map.get(latest.status or "", resp_msg)
        return success_response(
            DanceTaskStatusOut(
                generation_task_id=latest.id,
                status=latest.status,
                dashscope_task_id=str(payload.get("dashscope_task_id") or ""),
                video_url=latest.result_url or payload.get("dashscope_video_url") or "",
                session_id=int(payload.get("session_id") or 0),
                user_message_id=int(payload.get("user_message_id") or 0),
                lover_message_id=int(payload.get("lover_message_id") or 0),
                error_msg=latest.error_msg,
            ),
            msg=resp_msg,
        )
|