import hashlib
import logging
import os
import random
import shutil
import subprocess
import tempfile
import time
from datetime import datetime
from typing import Optional

import oss2
import requests
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError

from ..config import settings
from ..db import SessionLocal, get_db
from ..deps import AuthedUser, get_current_user
from ..models import (
    ChatMessage,
    ChatSession,
    GenerationTask,
    Lover,
    SongLibrary,
    User,
)
from ..response import ApiResponse, success_response

try:
    import imageio_ffmpeg  # type: ignore
except Exception:  # pragma: no cover
    imageio_ffmpeg = None

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/dance", tags=["dance"])


def _ffmpeg_bin() -> str:
    """Return the ffmpeg executable to invoke.

    Preference: ``ffmpeg`` on PATH, then the binary bundled with the optional
    ``imageio_ffmpeg`` package, then the bare name ``"ffmpeg"`` (so a missing
    binary surfaces as ``FileNotFoundError`` when it is actually run).
    """
    found = shutil.which("ffmpeg")
    if found:
        return found
    if imageio_ffmpeg is not None:
        try:
            return imageio_ffmpeg.get_ffmpeg_exe()
        except Exception:
            pass
    return "ffmpeg"


def _ffprobe_bin() -> str:
    """Return the ffprobe executable to invoke.

    Preference: ``ffprobe`` on PATH, then ``ffprobe.exe`` / ``ffprobe`` located
    next to the ``imageio_ffmpeg`` ffmpeg binary, then the bare name.
    """
    found = shutil.which("ffprobe")
    if found:
        return found
    if imageio_ffmpeg is not None:
        try:
            ffmpeg_dir = os.path.dirname(imageio_ffmpeg.get_ffmpeg_exe())
            for name in ("ffprobe.exe", "ffprobe"):
                candidate = os.path.join(ffmpeg_dir, name)
                if os.path.exists(candidate):
                    return candidate
        except Exception:
            pass
    return "ffprobe"


# Target duration (seconds) of the merged dance video and of its BGM clip.
DANCE_TARGET_DURATION_SEC = 10

# Largest page size accepted by the history listings.
_HISTORY_MAX_PAGE_SIZE = 100


def _page_window(page: int, size: int) -> tuple[int, int]:
    """Clamp client pagination params and return ``(offset, limit)``.

    ``page`` is 1-based; values below 1 are treated as 1 so the SQL OFFSET is
    never negative. ``size`` is clamped to ``[1, _HISTORY_MAX_PAGE_SIZE]`` so
    LIMIT is never zero/negative or unbounded.
    """
    safe_page = max(1, page)
    safe_size = min(max(1, size), _HISTORY_MAX_PAGE_SIZE)
    return (safe_page - 1) * safe_size, safe_size


@router.get("/history", response_model=ApiResponse[list[dict]])
def get_dance_history(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    """List the user's successfully generated dance videos, newest first.

    Dance tasks are ``video`` generation tasks whose payload has a ``prompt``.
    Only ``succeeded`` tasks with a ``result_url`` are returned.

    Raises:
        HTTPException(404): the user has no lover.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    offset, limit = _page_window(page, size)
    tasks = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["prompt"].as_string().isnot(None),
            GenerationTask.status == "succeeded",
            GenerationTask.result_url.isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )

    result: list[dict] = []
    for task in tasks:
        payload = task.payload or {}
        result.append(
            {
                "id": task.id,
                "prompt": payload.get("prompt") or "",
                "video_url": _cdnize(task.result_url) or "",
                "created_at": task.created_at.isoformat() if task.created_at else None,
            }
        )
    return success_response(result, msg="获取成功")


def _download_binary(url: str) -> bytes:
    """Download *url* fully into memory.

    Raises:
        HTTPException(502): transport error or non-200 response.
    """
    try:
        resp = requests.get(url, timeout=30)
    except Exception as exc:
        raise HTTPException(status_code=502, detail="文件下载失败") from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=502, detail="文件下载失败")
    return resp.content


def _retry_finalize_dance_task(task_id: int) -> None:
    """Rebuild a failed dance task whose DashScope job actually succeeded.

    Shared by the self-healing and manual-retry paths. It re-fetches the
    DashScope video, merges a random BGM clip and uploads the result to OSS.

    The work runs in three phases so that no row lock or open transaction is
    held during the slow download / ffmpeg / upload work:
      1. snapshot the task; bail out unless it is ``failed`` and carries a
         ``dashscope_task_id`` whose DashScope status is ``SUCCEEDED``;
      2. pick a BGM, merge and upload;
      3. lock the row and write the result, but only if the task is still
         ``failed`` (another worker may have resolved it meanwhile).

    Runs as a background task, so it never raises; failures are logged and the
    task row is left untouched.
    NOTE(review): the placeholder chat message is not updated here.
    """
    try:
        # Phase 1: read-only snapshot.
        with SessionLocal() as db:
            task = db.query(GenerationTask).filter(GenerationTask.id == task_id).first()
            if not task or task.status != "failed":
                return
            lover_id = task.lover_id
            dash_id = (task.payload or {}).get("dashscope_task_id")
        if not dash_id:
            return

        status, dash_video_url = _fetch_dashscope_status(str(dash_id))
        if status != "SUCCEEDED" or not dash_video_url:
            return

        # Phase 2: download DashScope video -> random BGM -> merge -> upload.
        with SessionLocal() as db:
            bgm_song = _pick_random_bgm(db)
            bgm_song_id = bgm_song.id
            bgm_audio_url_raw = bgm_song.audio_url
        bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw
        merged_bytes, bgm_meta = _merge_dance_video_with_bgm(
            dash_video_url,
            bgm_audio_url,
            DANCE_TARGET_DURATION_SEC,
        )
        object_name = f"lover/{lover_id}/dance/{int(time.time())}_retry.mp4"
        oss_url = _upload_to_oss(merged_bytes, object_name)

        # Phase 3: short locked write-back.
        with SessionLocal() as db:
            task = (
                db.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if not task or task.status != "failed":
                # Resolved concurrently; keep the other worker's result.
                return
            task.status = "succeeded"
            task.result_url = oss_url
            task.error_msg = None
            task.payload = {
                **(task.payload or {}),
                "dashscope_video_url": dash_video_url,
                "bgm_song_id": bgm_song_id,
                "bgm_audio_url": bgm_audio_url,
                "bgm_audio_url_raw": bgm_audio_url_raw,
                "bgm_start_sec": bgm_meta.get("bgm_start_sec"),
                "bgm_duration": DANCE_TARGET_DURATION_SEC,
                "retry_finalized": True,
            }
            task.updated_at = datetime.utcnow()
            db.add(task)
            db.commit()
    except Exception:
        # Best-effort background job: do not propagate, but leave a trace.
        logger.exception("dance retry finalize failed: task_id=%s", task_id)


@router.post("/retry/{task_id}", response_model=ApiResponse[dict])
def retry_dance_task(
    task_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Manually retry a failed dance task.

    Meant for tasks whose DashScope job succeeded but whose local
    download/merge/upload failed. Only ``failed`` tasks are accepted:
    retrying an in-flight task would race with its worker, and other states
    have nothing to repair.

    Raises:
        HTTPException(404): lover or task not found.
        HTTPException(409): task is still pending/running.
        HTTPException(400): task is not failed, or has no ``dashscope_task_id``.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.id == task_id,
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    if task.status in ("pending", "running"):
        raise HTTPException(status_code=409, detail="任务仍在进行中,请稍后再试")
    if task.status != "failed":
        raise HTTPException(status_code=400, detail="仅失败的任务可以重试")

    payload = task.payload or {}
    dash_id = payload.get("dashscope_task_id")
    if not dash_id:
        raise HTTPException(status_code=400, detail="任务缺少 dashscope_task_id,无法重试")

    task.payload = {**payload, "manual_retry": True}
    task.updated_at = datetime.utcnow()
    db.add(task)
    db.commit()

    background_tasks.add_task(_retry_finalize_dance_task, int(task.id))
    return success_response({"task_id": int(task.id)}, msg="已触发重试")


@router.get("/history/all", response_model=ApiResponse[list[dict]])
def get_dance_history_all(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    """List all of the user's dance tasks (succeeded, failed and in progress), newest first.

    ``video_url`` falls back from ``result_url`` to the payload's
    ``video_url`` / ``dashscope_video_url``.

    Raises:
        HTTPException(404): the user has no lover.
    """
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    offset, limit = _page_window(page, size)
    tasks = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .offset(offset)
        .limit(limit)
        .all()
    )

    result: list[dict] = []
    for task in tasks:
        payload = task.payload or {}
        url = task.result_url or payload.get("video_url") or payload.get("dashscope_video_url") or ""
        result.append(
            {
                "id": int(task.id),
                "prompt": payload.get("prompt") or "",
                "status": task.status,
                "video_url": _cdnize(url) or "",
                "error_msg": task.error_msg,
                "created_at": task.created_at.isoformat() if task.created_at else None,
            }
        )
    return success_response(result, msg="获取成功")


class DanceGenerateIn(BaseModel):
    """Request body for ``POST /dance/generate``."""

    # Free-text dance/motion description, 2-400 characters.
    prompt: str = Field(..., min_length=2, max_length=400, description="用户希望跳的舞/动作描述")


class DanceTaskStatusOut(BaseModel):
    """Status snapshot of one dance generation task."""

    generation_task_id: int
    status: str = Field(..., description="pending|running|succeeded|failed")
    dashscope_task_id: str = ""
    video_url: str = ""
    # Chat session / message ids reserved for this task (0 when absent).
    session_id: int = 0
    user_message_id: int = 0
    lover_message_id: int = 0
    error_msg: Optional[str] = None


@router.get("/current", response_model=ApiResponse[Optional[DanceTaskStatusOut]])
def get_current_dance_task(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the user's newest pending/running dance task, or ``None`` if there is none."""
    task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
            GenerationTask.status.in_(["pending", "running"]),
            GenerationTask.payload["prompt"].as_string().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .first()
    )
    if not task:
        return success_response(None, msg="暂无进行中的任务")

    payload = task.payload or {}
    raw_url = task.result_url or payload.get("video_url") or payload.get("dashscope_video_url") or ""
    return success_response(
        DanceTaskStatusOut(
            generation_task_id=task.id,
            status=task.status,
            dashscope_task_id=str(payload.get("dashscope_task_id") or ""),
            # CDN-ize like the history endpoints do.
            video_url=_cdnize(raw_url) or "",
            session_id=int(payload.get("session_id") or 0),
            user_message_id=int(payload.get("user_message_id") or 0),
            lover_message_id=int(payload.get("lover_message_id") or 0),
            error_msg=task.error_msg,
        ),
        msg="获取成功",
    )


def _upload_to_oss(file_bytes: bytes, object_name: str) -> str:
    """Upload *file_bytes* to the configured Aliyun OSS bucket and return its public URL.

    The URL is built on ``ALIYUN_OSS_CDN_DOMAIN`` when set, otherwise
    ``https://<bucket>.<endpoint host>/<object_name>``.

    Raises:
        HTTPException(500): OSS credentials or bucket/endpoint not configured.
    """
    if not settings.ALIYUN_OSS_ACCESS_KEY_ID or not settings.ALIYUN_OSS_ACCESS_KEY_SECRET:
        raise HTTPException(status_code=500, detail="未配置 OSS Key")
    if not settings.ALIYUN_OSS_BUCKET_NAME or not settings.ALIYUN_OSS_ENDPOINT:
        raise HTTPException(status_code=500, detail="未配置 OSS Bucket/Endpoint")

    auth = oss2.Auth(settings.ALIYUN_OSS_ACCESS_KEY_ID, settings.ALIYUN_OSS_ACCESS_KEY_SECRET)
    endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/")
    bucket = oss2.Bucket(auth, endpoint, settings.ALIYUN_OSS_BUCKET_NAME)
    bucket.put_object(object_name, file_bytes)

    cdn = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn:
        return f"{cdn.rstrip('/')}/{object_name}"
    return f"https://{settings.ALIYUN_OSS_BUCKET_NAME}.{endpoint.replace('https://', '').replace('http://', '')}/{object_name}"


def _is_own_oss_url(url: str) -> bool:
    """Return True if *url* points into our own CDN domain or OSS bucket.

    The CDN check requires a ``/`` right after the configured domain so that a
    look-alike host such as ``<cdn>.evil.com`` is not mistaken for ours.
    """
    if not url:
        return False
    cdn = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn and url.startswith(f"{cdn.rstrip('/')}/"):
        return True
    bucket = settings.ALIYUN_OSS_BUCKET_NAME
    endpoint = (settings.ALIYUN_OSS_ENDPOINT or "").rstrip("/")
    if bucket and endpoint:
        domain = endpoint.replace("https://", "").replace("http://", "")
        return url.startswith(f"https://{bucket}.{domain}/")
    return False


def _cdnize(url: Optional[str]) -> Optional[str]:
    """Turn a relative object path into an absolute URL.

    Absolute ``http(s)://`` URLs and empty values are returned as-is (after
    stripping). Otherwise the CDN domain is preferred, then bucket+endpoint,
    and finally a hard-coded fallback domain.
    """
    if not url:
        return url
    cleaned = url.strip()
    if cleaned.startswith("http://") or cleaned.startswith("https://"):
        return cleaned
    cleaned = cleaned.lstrip("/")
    if settings.ALIYUN_OSS_CDN_DOMAIN:
        return f"{settings.ALIYUN_OSS_CDN_DOMAIN.rstrip('/')}/{cleaned}"
    if settings.ALIYUN_OSS_BUCKET_NAME and settings.ALIYUN_OSS_ENDPOINT:
        endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/").replace("https://", "").replace("http://", "")
        return f"https://{settings.ALIYUN_OSS_BUCKET_NAME}.{endpoint}/{cleaned}"
    return f"https://nvlovers.oss-cn-qingdao.aliyuncs.com/{cleaned}"


def _shorten_url(url: str, max_len: int = 160) -> str:
    """Truncate *url* to *max_len* characters (ending in ``...``) for error messages."""
    if len(url) <= max_len:
        return url
    return f"{url[: max_len - 3]}..."
def _submit_video_task(prompt: str, image_url: str) -> str: if not settings.DASHSCOPE_API_KEY: raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY") resolution = settings.VIDEO_GEN_RESOLUTION or "480P" duration = settings.VIDEO_GEN_DURATION or 5 payload = { "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash", "input": { "prompt": prompt, "img_url": image_url, }, "parameters": { "resolution": resolution, "duration": duration, "watermark": False, "prompt_extend": True, }, } headers = { "X-DashScope-Async": "enable", "Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}", "Content-Type": "application/json", } try: resp = requests.post( "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis", headers=headers, json=payload, timeout=10, ) except Exception as exc: raise HTTPException(status_code=502, detail="调用视频生成接口失败") from exc if resp.status_code != 200: msg = resp.text try: msg = resp.json().get("message") or msg except Exception: pass raise HTTPException(status_code=502, detail=f"视频任务提交失败: {msg}") try: data = resp.json() except Exception as exc: raise HTTPException(status_code=502, detail="视频任务返回解析失败") from exc task_id = ( data.get("output", {}).get("task_id") or data.get("task_id") or data.get("output", {}).get("id") ) if not task_id: raise HTTPException(status_code=502, detail="视频任务未返回 task_id") return str(task_id) def _poll_video_url(task_id: str) -> str: headers = {"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"} query_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" deadline = time.time() + 120 while time.time() < deadline: time.sleep(3) try: resp = requests.get(query_url, headers=headers, timeout=8) except Exception: continue if resp.status_code != 200: continue try: data = resp.json() except Exception: continue output = data.get("output") or {} status_str = str( output.get("task_status") or data.get("task_status") or data.get("status") or "" ).upper() if status_str == "SUCCEEDED": url = 
output.get("video_url") or data.get("video_url") if not url: raise HTTPException(status_code=502, detail="视频生成成功但未返回结果 URL") return url if status_str == "FAILED": msg = output.get("message") or data.get("message") or "生成失败" raise HTTPException(status_code=502, detail=f"视频生成失败: {msg}") raise HTTPException(status_code=504, detail="视频生成超时,请稍后重试") def _fetch_dashscope_status(task_id: str) -> tuple[str, Optional[str]]: """ 直接查询 dashscope 任务状态。 返回 (status, video_url|None);status 取 SUCCEEDED/FAILED/其他。 """ headers = {"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"} query_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" try: resp = requests.get(query_url, headers=headers, timeout=8) except Exception: return "UNKNOWN", None if resp.status_code != 200: return "UNKNOWN", None try: data = resp.json() except Exception: return "UNKNOWN", None output = data.get("output") or {} status_str = str( output.get("task_status") or data.get("task_status") or data.get("status") or "" ).upper() video_url = output.get("video_url") or data.get("video_url") return status_str, video_url def _download_to_path(url: str, target_path: str, label: str): if not url: raise HTTPException(status_code=502, detail=f"{label}下载失败: 空URL") try: resp = requests.get(url, stream=True, timeout=30) except Exception as exc: raise HTTPException(status_code=502, detail=f"{label}下载失败: {_shorten_url(url)}") from exc if resp.status_code != 200: raise HTTPException(status_code=502, detail=f"{label}下载失败({resp.status_code}): {_shorten_url(url)}") try: with open(target_path, "wb") as file_handle: for chunk in resp.iter_content(chunk_size=1024 * 1024): if chunk: file_handle.write(chunk) finally: resp.close() def _probe_media_duration(path: str) -> Optional[float]: command = [ _ffprobe_bin(), "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", path, ] try: result = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) except 
FileNotFoundError: return None except subprocess.CalledProcessError: return None raw = result.stdout.decode("utf-8", errors="ignore").strip() if not raw: return None try: duration = float(raw) except ValueError: return None if duration <= 0: return None return duration def _run_ffmpeg_merge(video_path: str, audio_path: str, output_path: str): audio_duration = _probe_media_duration(audio_path) command = [ _ffmpeg_bin(), "-y", "-loglevel", "error", "-stream_loop", "-1", "-i", video_path, "-i", audio_path, ] if audio_duration: command.extend(["-t", f"{audio_duration:.3f}"]) command += [ "-map", "0:v:0", "-map", "1:a:0", "-c:v", "libx264", "-preset", "veryfast", "-crf", "23", "-pix_fmt", "yuv420p", "-c:a", "aac", "-shortest", "-movflags", "+faststart", output_path, ] try: subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) except FileNotFoundError as exc: raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc except subprocess.CalledProcessError as exc: stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else "" raise HTTPException(status_code=502, detail=f"ffmpeg 合成失败: {stderr[:200]}") from exc def _extract_audio_segment( input_path: str, start_sec: float, duration_sec: float, output_path: str, ): command = [ _ffmpeg_bin(), "-y", "-loglevel", "error", "-i", input_path, "-ss", f"{start_sec:.3f}", "-t", f"{duration_sec:.3f}", "-vn", "-acodec", "libmp3lame", "-b:a", "128k", "-ar", "44100", "-ac", "2", output_path, ] try: subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) except FileNotFoundError as exc: raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc except subprocess.CalledProcessError as exc: stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else "" raise HTTPException(status_code=502, detail=f"音频分段失败: {stderr[:200]}") from exc def _pad_audio_segment( input_path: str, pad_sec: float, target_duration_sec: float, output_path: str, ): if 
pad_sec <= 0: return command = [ _ffmpeg_bin(), "-y", "-loglevel", "error", "-i", input_path, "-af", f"apad=pad_dur={pad_sec:.3f}", "-t", f"{target_duration_sec:.3f}", "-acodec", "libmp3lame", "-b:a", "128k", "-ar", "44100", "-ac", "2", output_path, ] try: subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) except FileNotFoundError as exc: raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc except subprocess.CalledProcessError as exc: stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else "" raise HTTPException(status_code=502, detail=f"音频补齐失败: {stderr[:200]}") from exc def _pick_random_bgm(db: Session) -> SongLibrary: base_query = ( db.query(SongLibrary) .filter( SongLibrary.status.is_(True), SongLibrary.deletetime.is_(None), SongLibrary.audio_url.isnot(None), SongLibrary.audio_url != "", ) ) total = base_query.count() if total <= 0: raise HTTPException(status_code=404, detail="音频库暂无可用音频") offset = random.randrange(total) song = base_query.offset(offset).limit(1).first() if not song or not song.audio_url: raise HTTPException(status_code=404, detail="音频库暂无可用音频") return song def _merge_dance_video_with_bgm( base_video_url: str, audio_url: str, target_duration_sec: int, ) -> tuple[bytes, dict]: with tempfile.TemporaryDirectory() as tmpdir: base_video_path = os.path.join(tmpdir, "base.mp4") audio_source_path = os.path.join(tmpdir, "audio_source.mp3") _download_to_path(base_video_url, base_video_path, "跳舞视频") resolved_audio_url = _cdnize(audio_url) or audio_url _download_to_path(resolved_audio_url, audio_source_path, "BGM音频") audio_duration = _probe_media_duration(audio_source_path) if not audio_duration: raise HTTPException(status_code=502, detail="音频时长解析失败") start_sec = 0.0 clip_duration = min(audio_duration, float(target_duration_sec)) if audio_duration > target_duration_sec: max_start = max(0.0, audio_duration - target_duration_sec) start_sec = random.uniform(0.0, max_start) if max_start > 0 else 
0.0 clip_duration = float(target_duration_sec) clip_path = os.path.join(tmpdir, "audio_clip.mp3") _extract_audio_segment(audio_source_path, start_sec, clip_duration, clip_path) if clip_duration < target_duration_sec: padded_path = os.path.join(tmpdir, "audio_clip_pad.mp3") _pad_audio_segment( clip_path, target_duration_sec - clip_duration, target_duration_sec, padded_path, ) clip_path = padded_path output_path = os.path.join(tmpdir, "merged.mp4") _run_ffmpeg_merge(base_video_path, clip_path, output_path) with open(output_path, "rb") as file_handle: merged_bytes = file_handle.read() return merged_bytes, { "bgm_start_sec": round(start_sec, 3), "bgm_duration": target_duration_sec, } def _get_or_create_session(db: Session, user: AuthedUser, lover: Lover, session_id: Optional[int]) -> ChatSession: if session_id: session = ( db.query(ChatSession) .filter(ChatSession.id == session_id, ChatSession.user_id == user.id) .first() ) if not session: raise HTTPException(status_code=404, detail="会话不存在") return session active_sessions = ( db.query(ChatSession) .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id, ChatSession.status == "active") .with_for_update() .order_by(ChatSession.created_at.desc()) .all() ) if active_sessions: primary = active_sessions[0] for extra in active_sessions[1:]: if extra.status == "active": extra.status = "archived" db.add(extra) return primary now = datetime.utcnow() session = ChatSession( user_id=user.id, lover_id=lover.id, model=settings.LLM_MODEL or "qwen-flash", status="active", last_message_at=now, created_at=now, updated_at=now, inner_voice_enabled=False, ) db.add(session) db.flush() return session def _next_seq(db: Session, session_id: int) -> int: last_msg = ( db.query(ChatMessage) .filter(ChatMessage.session_id == session_id) .order_by(ChatMessage.seq.desc()) .first() ) return (last_msg.seq if last_msg and last_msg.seq else 0) + 1 def _mark_task_failed(db: Session, task_id: int, msg: str): task = ( 
db.query(GenerationTask) .filter(GenerationTask.id == task_id) .with_for_update() .first() ) if not task: return payload = task.payload or {} lover_msg_id = payload.get("lover_message_id") session_id = payload.get("session_id") if lover_msg_id: lover_msg = ( db.query(ChatMessage) .filter(ChatMessage.id == lover_msg_id) .with_for_update() .first() ) if lover_msg: lover_msg.content = f"跳舞视频生成失败:{msg}" lover_msg.extra = { **(lover_msg.extra or {}), "generation_status": "failed", "error_msg": msg, "generation_task_id": task_id, } lover_msg.tts_status = lover_msg.tts_status or "pending" db.add(lover_msg) if session_id: session = ( db.query(ChatSession) .filter(ChatSession.id == session_id) .with_for_update() .first() ) if session: session.last_message_at = datetime.utcnow() db.add(session) task.status = "failed" task.error_msg = msg[:255] task.updated_at = datetime.utcnow() db.add(task) db.commit() def _process_dance_task(task_id: int): # Phase 1: 读取任务并标记 running,快速退出连接 task = None prompt = "" image_url = "" session_id = None user_message_id = None lover_message_id = None dash_task_id = None user_id = None lover_id = None bgm_song_id = None bgm_audio_url_raw = None bgm_audio_url = None bgm_start_sec = None try: db = SessionLocal() task = ( db.query(GenerationTask) .filter(GenerationTask.id == task_id) .with_for_update() .first() ) if not task or task.status in ("succeeded", "failed"): db.rollback() return user_id = task.user_id lover_id = task.lover_id payload = task.payload or {} prompt = payload.get("prompt") or "" image_url = payload.get("image_url") session_id = payload.get("session_id") user_message_id = payload.get("user_message_id") lover_message_id = payload.get("lover_message_id") dash_task_id = payload.get("dashscope_task_id") # 基础存在性校验,快速失败 lover = db.query(Lover).filter(Lover.id == lover_id).first() user_row = db.query(User).filter(User.id == user_id).first() if not lover: raise HTTPException(status_code=404, detail="恋人不存在,请重新创建") if not lover.image_url: 
raise HTTPException(status_code=400, detail="请先生成并确认恋人形象") if not user_row: raise HTTPException(status_code=404, detail="用户不存在") if (user_row.video_gen_remaining or 0) <= 0: raise HTTPException(status_code=400, detail="视频生成次数不足") image_url = image_url or lover.image_url task.status = "running" task.updated_at = datetime.utcnow() db.add(task) db.commit() except HTTPException as exc: try: _mark_task_failed( db, task_id, str(exc.detail) if hasattr(exc, "detail") else str(exc), ) except Exception: pass finally: try: db.close() except Exception: pass return except Exception as exc: try: _mark_task_failed(db, task_id, str(exc)[:255]) except Exception: pass finally: try: db.close() except Exception: pass return finally: try: db.close() except Exception: pass # Phase 2: 提交/轮询 DashScope(无 DB 连接) try: if not dash_task_id: dash_task_id = _submit_video_task(prompt, image_url) with SessionLocal() as db_tmp: task_row = ( db_tmp.query(GenerationTask) .filter(GenerationTask.id == task_id) .with_for_update() .first() ) if task_row: task_row.payload = {**(task_row.payload or {}), "dashscope_task_id": dash_task_id} task_row.status = "running" task_row.updated_at = datetime.utcnow() db_tmp.add(task_row) db_tmp.commit() dash_video_url = _poll_video_url(dash_task_id) with SessionLocal() as db_tmp: bgm_song = _pick_random_bgm(db_tmp) bgm_song_id = bgm_song.id bgm_audio_url_raw = bgm_song.audio_url bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw task_row = ( db_tmp.query(GenerationTask) .filter(GenerationTask.id == task_id) .with_for_update() .first() ) if task_row: task_row.payload = { **(task_row.payload or {}), "bgm_song_id": bgm_song_id, "bgm_audio_url": bgm_audio_url, "bgm_audio_url_raw": bgm_audio_url_raw, } task_row.updated_at = datetime.utcnow() db_tmp.add(task_row) db_tmp.commit() merged_bytes, bgm_meta = _merge_dance_video_with_bgm( dash_video_url, bgm_audio_url, DANCE_TARGET_DURATION_SEC, ) bgm_start_sec = bgm_meta.get("bgm_start_sec") safe_prompt_tag = "prompt" 
object_name = f"lover/{lover_id}/dance/{int(time.time())}_{safe_prompt_tag}.mp4" oss_url = _upload_to_oss(merged_bytes, object_name) except Exception as exc: try: db_fail = SessionLocal() _mark_task_failed(db_fail, task_id, str(exc) if not hasattr(exc, "detail") else str(exc.detail)) except Exception: pass return # Phase 3: 回写结果与消息(短事务) db = SessionLocal() try: task_row = ( db.query(GenerationTask) .filter(GenerationTask.id == task_id) .with_for_update() .first() ) if not task_row: db.rollback() return lover = db.query(Lover).filter(Lover.id == task_row.lover_id).first() user_row = ( db.query(User) .filter(User.id == task_row.user_id) .with_for_update() .first() ) if not lover or not user_row: raise HTTPException(status_code=404, detail="用户或恋人不存在") # 获取/创建会话 session = None if session_id: session = ( db.query(ChatSession) .filter(ChatSession.id == session_id, ChatSession.user_id == user_row.id) .with_for_update() .first() ) if not session: session = _get_or_create_session( db, AuthedUser( id=user_row.id, reg_step=user_row.reg_step or 0, gender=user_row.gender or 0, nickname=user_row.nickname or user_row.username or "", token=user_row.token or "", ), lover, None, ) # 用户消息 user_msg = None if user_message_id: user_msg = ( db.query(ChatMessage) .filter(ChatMessage.id == user_message_id, ChatMessage.session_id == session.id) .with_for_update() .first() ) if not user_msg: now = datetime.utcnow() next_seq = _next_seq(db, session.id) user_msg = ChatMessage( session_id=session.id, user_id=user_row.id, lover_id=lover.id, role="user", content_type="text", content=prompt, seq=next_seq, created_at=now, model=settings.LLM_MODEL or "qwen-flash", ) db.add(user_msg) db.flush() next_seq = user_msg.seq or _next_seq(db, session.id) # 恋人消息 lover_msg = None if lover_message_id: lover_msg = ( db.query(ChatMessage) .filter(ChatMessage.id == lover_message_id, ChatMessage.session_id == session.id) .with_for_update() .first() ) if not lover_msg: lover_msg = ChatMessage( session_id=session.id, 
user_id=user_row.id, lover_id=lover.id, role="lover", content_type="text", content="正在为你生成跳舞视频,完成后会自动更新此消息", seq=next_seq + 1, created_at=datetime.utcnow(), model=settings.LLM_MODEL or "qwen-flash", extra={ "generation_task_id": task_row.id, "generation_status": "pending", "prompt": prompt, }, tts_status="pending", ) db.add(lover_msg) db.flush() lover_msg.content = f"为你生成了一段跳舞视频,点击查看:{oss_url}" lover_msg.extra = { **(lover_msg.extra or {}), "video_url": oss_url, "dashscope_video_url": dash_video_url, "dashscope_task_id": dash_task_id, "generation_task_id": task_row.id, "prompt": prompt, "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash", "resolution": settings.VIDEO_GEN_RESOLUTION or "480P", "duration": DANCE_TARGET_DURATION_SEC, "base_duration": settings.VIDEO_GEN_DURATION or 5, "watermark": False, "generation_status": "succeeded", "bgm_song_id": bgm_song_id, "bgm_audio_url": bgm_audio_url, "bgm_audio_url_raw": bgm_audio_url_raw, "bgm_start_sec": bgm_start_sec, "bgm_duration": DANCE_TARGET_DURATION_SEC, } lover_msg.tts_status = lover_msg.tts_status or "pending" db.add(lover_msg) # 扣减额度并刷新会话时间(不允许扣成负数,只扣一次) already_deducted = (task_row.payload or {}).get("deducted") remaining = user_row.video_gen_remaining or 0 if remaining > 0 and not already_deducted: user_row.video_gen_remaining = remaining - 1 session.last_message_at = datetime.utcnow() db.add(user_row) db.add(session) task_row.status = "succeeded" task_row.result_url = oss_url task_row.payload = { **(task_row.payload or {}), "dashscope_video_url": dash_video_url, "dashscope_task_id": dash_task_id, "session_id": session.id, "user_message_id": user_msg.id, "lover_message_id": lover_msg.id, "deducted": True, "output_duration": DANCE_TARGET_DURATION_SEC, "bgm_song_id": bgm_song_id, "bgm_audio_url": bgm_audio_url, "bgm_audio_url_raw": bgm_audio_url_raw, "bgm_start_sec": bgm_start_sec, "bgm_duration": DANCE_TARGET_DURATION_SEC, } task_row.updated_at = datetime.utcnow() db.add(task_row) db.commit() except 
HTTPException as exc: db.rollback() try: _mark_task_failed( db, task_id, str(exc.detail) if hasattr(exc, "detail") else str(exc), ) except Exception: pass except Exception as exc: db.rollback() try: _mark_task_failed(db, task_id, str(exc)[:255]) except Exception: pass finally: db.close() @router.post("/generate", response_model=ApiResponse[DanceTaskStatusOut]) def generate_dance_video( payload: DanceGenerateIn, background_tasks: BackgroundTasks, db: Session = Depends(get_db), user: AuthedUser = Depends(get_current_user), ): prompt = (payload.prompt or "").strip() if not prompt: raise HTTPException(status_code=400, detail="请输入想要跳的舞/唱的歌/动作描述") lover = db.query(Lover).filter(Lover.user_id == user.id).first() if not lover: raise HTTPException(status_code=404, detail="恋人不存在,请先完成创建流程") if not lover.image_url: raise HTTPException(status_code=400, detail="请先生成并确认恋人形象") user_row = ( db.query(User) .filter(User.id == user.id) .with_for_update() .first() ) if not user_row: raise HTTPException(status_code=404, detail="用户不存在") if (user_row.video_gen_remaining or 0) <= 0: raise HTTPException(status_code=400, detail="视频生成次数不足") pending_task = ( db.query(GenerationTask) .filter( GenerationTask.user_id == user.id, GenerationTask.task_type == "video", GenerationTask.status.in_(["pending", "running"]), ) .first() ) if pending_task: raise HTTPException(status_code=409, detail="已有视频生成任务进行中,请稍后再试") idem_key_src = f"video:{user.id}:{lover.image_url}:{prompt}" idem_key = hashlib.sha256(idem_key_src.encode("utf-8")).hexdigest() task = GenerationTask( user_id=user.id, lover_id=lover.id, task_type="video", status="pending", idempotency_key=idem_key, payload={ "image_url": lover.image_url, "prompt": prompt, "model": settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash", "resolution": settings.VIDEO_GEN_RESOLUTION or "480P", "duration": settings.VIDEO_GEN_DURATION or 5, "watermark": False, "output_duration": DANCE_TARGET_DURATION_SEC, }, created_at=datetime.utcnow(), updated_at=datetime.utcnow(), ) 
db.add(task) try: db.flush() except IntegrityError: db.rollback() existing = ( db.query(GenerationTask) .filter(GenerationTask.idempotency_key == idem_key) .first() ) if existing and existing.status in ("pending", "running"): raise HTTPException(status_code=409, detail="已提交相同的视频生成任务,请稍后查看结果") # 既往任务已结束,避免幂等键冲突,生成新的短哈希重试 retry_key = hashlib.sha256(f"{idem_key}:{time.time()}".encode()).hexdigest() task.idempotency_key = retry_key db.add(task) db.flush() # 预留会话与消息序号,便于生成完成后不会被后续消息“插队” session = _get_or_create_session(db, user, lover, None) now = datetime.utcnow() next_seq = _next_seq(db, session.id) user_msg = ChatMessage( session_id=session.id, user_id=user.id, lover_id=lover.id, role="user", content_type="text", content=prompt, seq=next_seq, created_at=now, model=settings.LLM_MODEL or "qwen-flash", ) db.add(user_msg) db.flush() lover_msg = ChatMessage( session_id=session.id, user_id=user.id, lover_id=lover.id, role="lover", content_type="text", content="正在为你生成跳舞视频,完成后会自动更新此消息", seq=next_seq + 1, created_at=datetime.utcnow(), model=settings.LLM_MODEL or "qwen-flash", extra={ "generation_task_id": task.id, "generation_status": "pending", "prompt": prompt, }, tts_status="pending", ) db.add(lover_msg) db.flush() session.last_message_at = datetime.utcnow() db.add(session) task.payload = { **(task.payload or {}), "session_id": session.id, "user_message_id": user_msg.id, "lover_message_id": lover_msg.id, } db.add(task) db.commit() background_tasks.add_task(_process_dance_task, task.id) return success_response( DanceTaskStatusOut( generation_task_id=task.id, status="pending", dashscope_task_id="", video_url="", session_id=session.id, user_message_id=user_msg.id, lover_message_id=lover_msg.id, error_msg=None, ), msg="视频生成任务已提交,正在生成", ) @router.get("/generate/{task_id}", response_model=ApiResponse[DanceTaskStatusOut]) def get_dance_task( task_id: int, background_tasks: BackgroundTasks, db: Session = Depends(get_db), user: AuthedUser = Depends(get_current_user), ): task = ( 
db.query(GenerationTask) .filter( GenerationTask.id == task_id, GenerationTask.user_id == user.id, GenerationTask.task_type == "video", ) .first() ) if not task: raise HTTPException(status_code=404, detail="任务不存在") status_msg_map = { "pending": "视频生成中", "running": "视频生成中", "succeeded": "视频生成成功", "failed": "视频生成失败", } resp_msg = status_msg_map.get(task.status or "", "查询成功") # 若状态仍为 pending/running,补偿性触发处理(同步 + 异步双保险) if task.status in ("pending", "running"): # 避免阻塞,优先异步处理 background_tasks.add_task(_process_dance_task, task.id) # 用新会话读取最新状态,避免当前事务快照阻塞 with SessionLocal() as tmp: latest = ( tmp.query(GenerationTask) .filter( GenerationTask.id == task.id, GenerationTask.user_id == user.id, GenerationTask.task_type == "video", ) .first() ) if latest: task = latest # 若仍然卡在 running,直接查询 dashscope,并回填结果/失败,避免无限等待 if task.status in ("pending", "running") and (task.payload or {}).get("dashscope_task_id"): status, dash_video = _fetch_dashscope_status((task.payload or {}).get("dashscope_task_id")) with SessionLocal() as tmp: current = ( tmp.query(GenerationTask) .filter( GenerationTask.id == task.id, GenerationTask.user_id == user.id, GenerationTask.task_type == "video", ) .with_for_update() .first() ) if current: if status == "SUCCEEDED": dash_url = ( dash_video or (current.payload or {}).get("dashscope_video_url") or current.result_url ) final_url = current.result_url or dash_url bgm_song_id = (current.payload or {}).get("bgm_song_id") bgm_audio_url = (current.payload or {}).get("bgm_audio_url") bgm_start_sec = (current.payload or {}).get("bgm_start_sec") merge_failed = False if dash_url and not _is_own_oss_url(final_url or ""): try: bgm_song = _pick_random_bgm(tmp) bgm_song_id = bgm_song.id bgm_audio_url_raw = bgm_song.audio_url bgm_audio_url = _cdnize(bgm_audio_url_raw) or bgm_audio_url_raw merged_bytes, bgm_meta = _merge_dance_video_with_bgm( dash_url, bgm_audio_url, DANCE_TARGET_DURATION_SEC, ) bgm_start_sec = bgm_meta.get("bgm_start_sec") object_name = 
f"lover/{current.lover_id}/dance/{int(time.time())}_prompt.mp4" final_url = _upload_to_oss(merged_bytes, object_name) current.payload = { **(current.payload or {}), "bgm_song_id": bgm_song_id, "bgm_audio_url": bgm_audio_url, "bgm_audio_url_raw": bgm_audio_url_raw, } except Exception as exc: current.status = "failed" current.error_msg = str(exc) if not hasattr(exc, "detail") else str(exc.detail) current.updated_at = datetime.utcnow() tmp.add(current) tmp.commit() task = current merge_failed = True if not merge_failed: current.status = "succeeded" current.result_url = final_url current.payload = { **(current.payload or {}), "dashscope_video_url": dash_url or (current.payload or {}).get("dashscope_video_url"), "output_duration": DANCE_TARGET_DURATION_SEC, "bgm_song_id": bgm_song_id, "bgm_audio_url": bgm_audio_url, "bgm_start_sec": bgm_start_sec, "bgm_duration": DANCE_TARGET_DURATION_SEC, } current.updated_at = datetime.utcnow() # 尝试同步更新聊天占位消息与额度(避免重复扣,检查 deducted 标记) try: lover = tmp.query(Lover).filter(Lover.id == current.lover_id).first() user_row = ( tmp.query(User) .filter(User.id == current.user_id) .with_for_update() .first() ) session = None if current.payload.get("session_id"): session = ( tmp.query(ChatSession) .filter( ChatSession.id == current.payload.get("session_id"), ChatSession.user_id == current.user_id, ) .with_for_update() .first() ) lover_msg = None if current.payload.get("lover_message_id"): lover_msg = ( tmp.query(ChatMessage) .filter(ChatMessage.id == current.payload.get("lover_message_id")) .with_for_update() .first() ) if lover_msg: lover_msg.content = f"为你生成了一段跳舞视频,点击查看:{final_url}" lover_msg.extra = { **(lover_msg.extra or {}), "video_url": final_url, "dashscope_video_url": dash_url, "dashscope_task_id": current.payload.get("dashscope_task_id"), "generation_task_id": current.id, "prompt": current.payload.get("prompt"), "model": current.payload.get("model") or settings.VIDEO_GEN_MODEL or "wan2.2-i2v-flash", "resolution": 
current.payload.get("resolution") or settings.VIDEO_GEN_RESOLUTION or "480P", "duration": current.payload.get("output_duration") or DANCE_TARGET_DURATION_SEC, "base_duration": current.payload.get("duration") or settings.VIDEO_GEN_DURATION or 5, "watermark": current.payload.get("watermark", False), "generation_status": "succeeded", "bgm_song_id": current.payload.get("bgm_song_id"), "bgm_audio_url": current.payload.get("bgm_audio_url"), "bgm_audio_url_raw": current.payload.get("bgm_audio_url_raw"), "bgm_start_sec": current.payload.get("bgm_start_sec"), "bgm_duration": current.payload.get("bgm_duration") or DANCE_TARGET_DURATION_SEC, } lover_msg.tts_status = lover_msg.tts_status or "pending" tmp.add(lover_msg) if user_row and not (current.payload or {}).get("deducted"): remaining = user_row.video_gen_remaining or 0 if remaining > 0: user_row.video_gen_remaining = remaining - 1 current.payload = {**(current.payload or {}), "deducted": True} tmp.add(user_row) if session: session.last_message_at = datetime.utcnow() tmp.add(session) except Exception: # 避免阻断主流程 pass tmp.add(current) tmp.commit() task = current elif status == "FAILED": current.status = "failed" current.error_msg = "视频生成失败(DashScope 返回失败)" current.updated_at = datetime.utcnow() tmp.add(current) tmp.commit() task = current # 用新会话读取最终状态,避免 DetachedInstanceError with SessionLocal() as snap: latest = ( snap.query(GenerationTask) .filter( GenerationTask.id == task_id, GenerationTask.user_id == user.id, GenerationTask.task_type == "video", ) .first() ) if not latest: raise HTTPException(status_code=404, detail="任务不存在") payload = latest.payload or {} resp_msg = status_msg_map.get(latest.status or "", resp_msg) return success_response( DanceTaskStatusOut( generation_task_id=latest.id, status=latest.status, dashscope_task_id=str(payload.get("dashscope_task_id") or ""), video_url=latest.result_url or payload.get("dashscope_video_url") or "", session_id=int(payload.get("session_id") or 0), 
user_message_id=int(payload.get("user_message_id") or 0), lover_message_id=int(payload.get("lover_message_id") or 0), error_msg=latest.error_msg, ), msg=resp_msg, )