Ai_GirlFriend/lover/routers/sing.py
2026-03-02 18:57:11 +08:00

3096 lines
115 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import hashlib
import logging
import math
import os
import shutil
import subprocess
import tempfile
import threading
import time
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import List, Optional
import oss2
import requests
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ..config import settings
from ..db import SessionLocal, get_db
from ..deps import AuthedUser, get_current_user
from ..models import (
ChatMessage,
ChatSession,
EmoDetectCache,
GenerationTask,
Lover,
MusicLibrary,
SingBaseVideo,
SingSongVideo,
SongLibrary,
SongSegment,
SongSegmentVideo,
User,
)
from ..response import ApiResponse, success_response
from ..task_queue import sing_task_queue
# Optional dependency: imageio-ffmpeg supplies a bundled ffmpeg binary that
# _ffmpeg_bin/_ffprobe_bin fall back to when nothing is found on PATH.
try:
    import imageio_ffmpeg  # type: ignore
except Exception:  # pragma: no cover
    # Any import failure simply disables the bundled-binary fallback.
    imageio_ffmpeg = None
# Module logger, and the router for this module's endpoints (mounted under /sing).
logger = logging.getLogger("sing")
router = APIRouter(prefix="/sing", tags=["sing"])
def _ffmpeg_bin() -> str:
    """Return the ffmpeg executable to invoke.

    Resolution order: an ``ffmpeg`` found on PATH, then the binary bundled with
    imageio-ffmpeg (when that package imported and can locate it), and finally
    the bare name ``"ffmpeg"`` so a missing binary surfaces as
    FileNotFoundError when the subprocess is started.
    """
    system_path = shutil.which("ffmpeg")
    if system_path:
        return system_path
    if imageio_ffmpeg is None:
        return "ffmpeg"
    try:
        return imageio_ffmpeg.get_ffmpeg_exe()
    except Exception:
        # Bundled binary unavailable; fall back to the bare command name.
        return "ffmpeg"
def _ffprobe_bin() -> str:
    """Return the ffprobe executable to invoke.

    Prefers ``ffprobe`` on PATH. Otherwise, when imageio-ffmpeg is importable,
    looks for ``ffprobe.exe`` and then ``ffprobe`` in the directory holding its
    ffmpeg binary. Falls back to the bare name ``"ffprobe"``.
    """
    system_path = shutil.which("ffprobe")
    if system_path:
        return system_path
    # NOTE(review): imageio-ffmpeg is only documented to ship ffmpeg; finding an
    # ffprobe next to it is opportunistic — confirm for the deployed build.
    if imageio_ffmpeg is not None:
        try:
            bundle_dir = os.path.dirname(imageio_ffmpeg.get_ffmpeg_exe())
            for exe_name in ("ffprobe.exe", "ffprobe"):
                probe_path = os.path.join(bundle_dir, exe_name)
                if os.path.exists(probe_path):
                    return probe_path
        except Exception:
            pass
    return "ffprobe"
# Legacy wan2.5 image-to-video settings. _resolve_sing_base_config currently
# always selects the wan2.6 settings below; SING_BASE_PROMPT and
# SING_BASE_NEGATIVE_PROMPT are still returned by _resolve_sing_prompts for any
# model other than SING_WAN26_MODEL.
SING_BASE_MODEL = "wan2.5-i2v-preview"
SING_BASE_RESOLUTION = "480P"
# wan2.6 image-to-video settings (the model _resolve_sing_base_config returns).
SING_WAN26_MODEL = "wan2.6-i2v-flash"
SING_WAN26_RESOLUTION = "720P"
# Value sent as parameters["duration"] to the video model (presumably seconds).
SING_BASE_DURATION = 5
SING_MAX_DURATION = 30  # Maximum singing-video duration (seconds)
# Positive/negative prompts for the generic (non-wan2.6) model.
SING_BASE_PROMPT = (
    "front-facing full-body, modest outfit; camera locked on tripod; "
    "head and neck fixed, body still; "
    "natural singing with soft lip articulation and varied mouth shapes, small-to-medium opening, smooth transitions; "
    "subtle jaw motion only; gentle expressive hand gestures near chest; "
    "realistic lighting, sharp details; pose consistent from start to end"
)
SING_BASE_NEGATIVE_PROMPT = (
    "nudity, cleavage, lingerie, bikini, see-through, sexualized, provocative; "
    "camera movement, zoom, pan, shake, head turning, head bobbing, body swaying, dancing, walking; "
    "exaggerated expressions; chewing, biting, eating, fish mouth, O-shaped mouth, "
    "mouth stuck open, mouth stuck closed, repetitive mechanical mouth motion, extreme mouth shapes, lip deformation; "
    "blur, artifacts, watermark, text, bad hands, extra fingers"
)
# Prompts for wan2.6, which is sent the song audio and asked to lip-sync to it.
SING_WAN26_PROMPT = (
    "front-facing full-body, modest outfit; camera locked on tripod; "
    "head and neck fixed, body still; "
    "lip-sync to the provided audio lyrics with natural mouth articulation and varied mouth shapes, "
    "clear closures between syllables, smooth transitions; "
    "subtle jaw motion only; gentle expressive hand gestures near chest; "
    "natural facial expressions matching the song emotion: smile for happy songs, gentle sadness for melancholic songs, "
    "soft eye movements and eyebrow raises to convey emotion; "
    "realistic lighting, sharp details; pose consistent from start to end"
)
SING_WAN26_NEGATIVE_PROMPT = (
    "nudity, cleavage, lingerie, bikini, see-through, sexualized, provocative; "
    "camera movement, zoom, pan, shake, head turning, head bobbing, body swaying, dancing, walking; "
    "exaggerated expressions; chewing, biting, eating, fish mouth, O-shaped mouth, "
    "mouth stuck open, mouth stuck closed, out-of-sync mouth motion, repetitive mechanical mouth motion, "
    "extreme mouth shapes, lip deformation; "
    "blur, artifacts, watermark, text, bad hands, extra fingers"
)
# Sent as parameters["prompt_extend"] by _submit_sing_task.
SING_BASE_PROMPT_EXTEND = False
# NOTE(review): not used in this view; presumably wait/cooldown intervals in
# seconds for the task-queue logic elsewhere in the file — confirm.
SING_BASE_WAIT_SECONDS = 180
SING_MERGE_WAIT_SECONDS = 60
SING_REQUEUE_COOLDOWN_SECONDS = 5
# DashScope EMO models and defaults.
EMO_DETECT_MODEL = "emo-detect-v1"  # face-detection model used by _emo_detect
EMO_MODEL = "emo-v1"  # video model used by _submit_emo_video
EMO_RATIO = "3:4"  # default ratio when a payload does not specify one
EMO_STYLE_LEVEL = "normal"  # default style_level when a payload does not specify one
EMO_SEGMENT_SECONDS = 60  # maximum audio segment length per EMO task
EMO_TASK_TIMEOUT_SECONDS = 1800  # NOTE(review): not used in this view; presumably an overall EMO task timeout
EMO_BACKFILL_STALE_SECONDS = 60  # "running" segments not updated for this long are eligible for backfill
EMO_BACKFILL_MIN_INTERVAL_SECONDS = 20  # minimum gap between backfill attempts while waiting on a segment
# Error code identifying an Aliyun content-safety rejection (matched
# case-insensitively as a substring), and the user-facing message for it.
EMO_CONTENT_SAFETY_CODE = "DataInspectionFailed"
EMO_CONTENT_SAFETY_MESSAGE = "歌词内容触发了阿里云的内容安全审核机制"
# Process-wide concurrency limits (at least 1). The ffmpeg helpers in this
# module acquire _sing_merge_semaphore around each ffmpeg subprocess call.
_sing_merge_semaphore = threading.BoundedSemaphore(max(1, settings.SING_MERGE_MAX_CONCURRENCY or 1))
# NOTE(review): not used in this view; presumably guards the per-key
# last-enqueue timestamps used for the requeue cooldown — confirm.
_sing_enqueue_lock = threading.Lock()
_sing_last_enqueue_at: dict[int, float] = {}
# NOTE(review): not used in this view; presumably bounds concurrent EMO tasks.
_emo_task_semaphore = threading.BoundedSemaphore(max(1, settings.EMO_MAX_CONCURRENCY or 1))
def _check_and_reset_vip_video_gen(user_row: User, db: Session) -> None:
    """Reset a VIP user's daily video-generation quota on the first call of a new day.

    Missing users and non-VIP users are left untouched. For a VIP user whose
    ``video_gen_reset_date`` is unset or earlier than today's UTC date,
    ``video_gen_remaining`` is set to 2 and the reset date is set to today.
    The change is flushed but not committed.

    Args:
        user_row: User row to inspect and possibly update.
        db: Session used to stage and flush the update.
    """
    if not user_row:
        return
    # vip_endtime is a Unix timestamp, so compare against time.time().
    # datetime.utcnow().timestamp() is wrong here: .timestamp() interprets the
    # naive UTC datetime as *local* time, skewing "now" by the host's UTC
    # offset (e.g. 8 hours on a UTC+8 server).
    current_timestamp = int(time.time())
    is_vip = user_row.vip_endtime and user_row.vip_endtime > current_timestamp
    if not is_vip:
        return
    last_reset = user_row.video_gen_reset_date
    today = datetime.utcnow().date()
    # First call on a new (UTC) day: restore the daily quota.
    if not last_reset or last_reset < today:
        user_row.video_gen_remaining = 2  # VIP users get 2 generations per day
        user_row.video_gen_reset_date = today
        db.add(user_row)
        db.flush()
@contextmanager
def _semaphore_guard(semaphore: threading.BoundedSemaphore):
semaphore.acquire()
try:
yield
finally:
semaphore.release()
def _cdnize(url: Optional[str]) -> Optional[str]:
    """Turn a stored OSS object path into an absolute URL.

    Falsy input is returned unchanged, and ``http(s)://`` URLs are returned
    stripped of surrounding whitespace. Relative paths (leading slashes
    removed) are joined to, in order of preference: the configured CDN domain,
    the ``https://<bucket>.<endpoint-host>`` origin, or the project's
    historical public OSS domain.
    """
    if not url:
        return url
    path = url.strip()
    if path.startswith(("http://", "https://")):
        return path
    # Drop leading slashes so the joins below never produce "//".
    path = path.lstrip("/")
    cdn_domain = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn_domain:
        return f"{cdn_domain.rstrip('/')}/{path}"
    bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
    endpoint = settings.ALIYUN_OSS_ENDPOINT
    if bucket_name and endpoint:
        host = endpoint.rstrip("/").replace("https://", "").replace("http://", "")
        return f"https://{bucket_name}.{host}/{path}"
    # Last resort: the public domain this project has historically used.
    return f"https://nvlovers.oss-cn-qingdao.aliyuncs.com/{path}"
def _extract_error_text(exc: Exception) -> str:
    """Return a human-readable message for ``exc``.

    For HTTPException the ``detail`` is used (stringified when it is not
    already a string); any other exception is converted with ``str()``.
    """
    if not isinstance(exc, HTTPException):
        return str(exc)
    detail = exc.detail
    return detail if isinstance(detail, str) else str(detail)
def _is_content_safety_error(value: Optional[str]) -> bool:
if not value:
return False
return EMO_CONTENT_SAFETY_CODE.lower() in value.lower()
def _build_sing_message_content(video_url: str, content_safety_blocked: bool) -> str:
if content_safety_blocked:
if video_url:
return f"{EMO_CONTENT_SAFETY_MESSAGE},已生成部分视频,点击查看:{video_url}"
return EMO_CONTENT_SAFETY_MESSAGE
if video_url:
return f"为你生成了一段唱歌视频,点击查看:{video_url}"
return "为你生成了一段唱歌视频"
def _upload_to_oss(file_bytes: bytes, object_name: str) -> str:
    """Upload ``file_bytes`` to OSS as ``object_name`` and return its public URL.

    The URL uses the CDN domain when one is configured, otherwise the
    ``https://<bucket>.<endpoint-host>/<object_name>`` form.

    Raises:
        HTTPException: 500 when the OSS credentials or bucket/endpoint are not configured.
    """
    key_id = settings.ALIYUN_OSS_ACCESS_KEY_ID
    key_secret = settings.ALIYUN_OSS_ACCESS_KEY_SECRET
    if not (key_id and key_secret):
        raise HTTPException(status_code=500, detail="未配置 OSS Key")
    bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
    if not (bucket_name and settings.ALIYUN_OSS_ENDPOINT):
        raise HTTPException(status_code=500, detail="未配置 OSS Bucket/Endpoint")
    endpoint = settings.ALIYUN_OSS_ENDPOINT.rstrip("/")
    bucket = oss2.Bucket(oss2.Auth(key_id, key_secret), endpoint, bucket_name)
    bucket.put_object(object_name, file_bytes)
    cdn_domain = settings.ALIYUN_OSS_CDN_DOMAIN
    if cdn_domain:
        return f"{cdn_domain.rstrip('/')}/{object_name}"
    endpoint_host = endpoint.replace("https://", "").replace("http://", "")
    return f"https://{bucket_name}.{endpoint_host}/{object_name}"
def _hash_text(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _hash_file(path: str) -> str:
hash_obj = hashlib.sha256()
with open(path, "rb") as file_handle:
while True:
chunk = file_handle.read(1024 * 1024)
if not chunk:
break
hash_obj.update(chunk)
return hash_obj.hexdigest()
def _build_prompt_hash(
    prompt: str,
    negative_prompt: Optional[str],
    audio_hash: Optional[str] = None,
    use_audio: bool = False,
) -> str:
    """Return a SHA-256 cache key for a prompt configuration.

    The hashed text is ``prompt``, followed by ``"\\n--NEG--\\n<negative>"``
    when a negative prompt is given, followed by ``"\\n--AUDIO--\\n<hash>"``
    when ``use_audio`` is set and ``audio_hash`` is non-empty.
    """
    text = prompt
    if negative_prompt:
        text = f"{text}\n--NEG--\n{negative_prompt}"
    if audio_hash and use_audio:
        text = f"{text}\n--AUDIO--\n{audio_hash}"
    return _hash_text(text)
def _resolve_sing_base_config(_: Optional[str]) -> tuple[str, str, int, bool]:
    """Return ``(model, resolution, duration, flag)`` for base-video generation.

    The argument is ignored: every request currently uses the wan2.6 model at
    its configured resolution with SING_BASE_DURATION. The trailing boolean is
    always True (presumably the "use audio" flag — confirm against callers).
    """
    model, resolution = SING_WAN26_MODEL, SING_WAN26_RESOLUTION
    return model, resolution, SING_BASE_DURATION, True
def _resolve_sing_prompts(model: str) -> tuple[str, str]:
    """Return ``(prompt, negative_prompt)`` for ``model``.

    The wan2.6 model gets the lip-sync prompts; any other model gets the
    generic base prompts.
    """
    is_wan26 = model == SING_WAN26_MODEL
    prompt = SING_WAN26_PROMPT if is_wan26 else SING_BASE_PROMPT
    negative = SING_WAN26_NEGATIVE_PROMPT if is_wan26 else SING_BASE_NEGATIVE_PROMPT
    return prompt, negative
def _download_to_path(url: str, target_path: str):
    """Stream the resource at ``url`` into the local file ``target_path``.

    Raises:
        HTTPException: 502 when the request cannot be made, the server does not
            answer 200, or the transfer fails part-way through.
    """
    try:
        logger.info(f"开始下载文件: {url}")
        resp = requests.get(url, stream=True, timeout=30)
    except Exception as exc:
        logger.error(f"文件下载失败 - URL: {url}, 错误: {exc}")
        raise HTTPException(status_code=502, detail="文件下载失败") from exc
    # Close the streamed response on every path, including the non-200 one;
    # otherwise its pooled connection is never released.
    try:
        if resp.status_code != 200:
            logger.error(f"文件下载失败 - URL: {url}, 状态码: {resp.status_code}")
            raise HTTPException(status_code=502, detail="文件下载失败")
        try:
            with open(target_path, "wb") as file_handle:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip empty chunks
                        file_handle.write(chunk)
        except requests.RequestException as exc:
            # Network errors mid-stream surface as the same 502 as other failures.
            logger.error(f"文件下载失败 - URL: {url}, 错误: {exc}")
            raise HTTPException(status_code=502, detail="文件下载失败") from exc
        logger.info(f"文件下载成功: {url} -> {target_path}")
    finally:
        resp.close()
def _emo_detect(image_url: str, ratio: str) -> dict:
    """Run DashScope EMO face detection on ``image_url`` for the given ``ratio``.

    Returns a dict with ``check_pass`` (bool), ``face_bbox``, ``ext_bbox`` and
    the full decoded response under ``raw``.

    Raises:
        HTTPException: 500 when DASHSCOPE_API_KEY is not configured; 502 when
            the request fails, returns a non-200 status, or is not valid JSON.
    """
    api_key = settings.DASHSCOPE_API_KEY
    if not api_key:
        raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
    request_body = {
        "model": EMO_DETECT_MODEL,
        "input": {"image_url": image_url},
        "parameters": {"ratio": ratio},
    }
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    endpoint = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/face-detect"
    try:
        resp = requests.post(endpoint, headers=request_headers, json=request_body, timeout=15)
    except Exception as exc:
        raise HTTPException(status_code=502, detail="调用EMO检测失败") from exc
    if resp.status_code != 200:
        # Prefer the API's own "message" field; fall back to the raw body.
        error_text = resp.text
        try:
            error_text = resp.json().get("message") or error_text
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"EMO检测失败: {error_text}")
    try:
        body = resp.json()
    except Exception as exc:
        raise HTTPException(status_code=502, detail="EMO检测返回解析失败") from exc
    result = body.get("output") or {}
    return {
        "check_pass": bool(result.get("check_pass")),
        "face_bbox": result.get("face_bbox"),
        "ext_bbox": result.get("ext_bbox"),
        "raw": body,
    }
def _ensure_emo_detect_cache(
    db: Session,
    lover_id: int,
    image_url: str,
    image_hash: str,
    ratio: str,
) -> EmoDetectCache:
    """Return the EMO face-detection cache row for ``(image_hash, ratio)``, creating it on a miss.

    On a miss, calls _emo_detect, adds the result and flushes (no commit). If
    the flush raises IntegrityError (presumably a concurrent request inserting
    the same key), the session is rolled back and the row is looked up again.
    The IntegrityError is re-raised if the row still cannot be found.
    """
    def _find_cached() -> Optional[EmoDetectCache]:
        return (
            db.query(EmoDetectCache)
            .filter(EmoDetectCache.image_hash == image_hash, EmoDetectCache.ratio == ratio)
            .first()
        )

    existing = _find_cached()
    if existing:
        return existing
    detection = _emo_detect(image_url, ratio)
    entry = EmoDetectCache(
        lover_id=lover_id,
        image_url=image_url,
        image_hash=image_hash,
        ratio=ratio,
        check_pass=detection.get("check_pass", False),
        face_bbox=detection.get("face_bbox"),
        ext_bbox=detection.get("ext_bbox"),
        raw_response=detection.get("raw"),
    )
    db.add(entry)
    try:
        db.flush()
    except IntegrityError:
        # NOTE(review): rollback() discards the caller's entire uncommitted
        # session state, not just this insert — confirm callers have nothing
        # pending. It also starts a fresh transaction, which lets the
        # re-query see the concurrently inserted row.
        db.rollback()
        winner = _find_cached()
        if winner:
            return winner
        raise
    return entry
def _probe_media_duration(path: str) -> Optional[float]:
    """Return the container duration of ``path`` in seconds, as reported by ffprobe.

    Returns None when ffprobe is missing or exits with an error, and when it
    prints nothing, a non-numeric value, or a non-positive duration.
    """
    probe_cmd = [
        _ffprobe_bin(),
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    try:
        completed = subprocess.run(probe_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    text = completed.stdout.decode("utf-8", errors="ignore").strip()
    try:
        seconds = float(text) if text else None
    except ValueError:
        return None
    if seconds is None or seconds <= 0:
        return None
    return seconds
def _run_ffmpeg_merge(video_path: str, audio_path: str, output_path: str):
    """Mux ``video_path`` (looped) with ``audio_path`` into ``output_path``.

    The video input is looped indefinitely (``-stream_loop -1``) so it can
    cover the whole audio track. The output is capped at the probed audio
    duration (``-t``) when ffprobe reports one, and by ``-shortest`` in any
    case. Video is re-encoded to H.264/yuv420p and audio to AAC. The ffmpeg
    run holds ``_sing_merge_semaphore``.

    Raises:
        HTTPException: 500 if the ffmpeg binary cannot be executed; 502 if
            ffmpeg exits with an error (first 200 chars of stderr included).
    """
    audio_duration = _probe_media_duration(audio_path)
    # Argument order matters to ffmpeg: -stream_loop applies to the input that
    # follows it (the video), while the options after both inputs (-t, -map,
    # codecs, ...) apply to the output.
    command = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel",
        "error",
        "-stream_loop",
        "-1",
        "-i",
        video_path,
        "-i",
        audio_path,
    ]
    if audio_duration:
        command.extend(["-t", f"{audio_duration:.3f}"])
    command += [
        "-map",
        "0:v:0",
        "-map",
        "1:a:0",
        "-c:v",
        "libx264",
        "-preset",
        "veryfast",
        "-crf",
        "23",
        "-pix_fmt",
        "yuv420p",
        "-c:a",
        "aac",
        "-shortest",
        "-movflags",
        "+faststart",
        output_path,
    ]
    try:
        with _semaphore_guard(_sing_merge_semaphore):
            subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"ffmpeg 合成失败: {stderr[:200]}") from exc
def _strip_video_audio(video_bytes: bytes) -> bytes:
    """Return ``video_bytes`` re-muxed without any audio track.

    Keeps only the first video stream (``-map 0:v:0`` with ``-an``). A stream
    copy is tried first; if ffmpeg fails, the video is re-encoded to
    H.264/yuv420p. Each ffmpeg run holds ``_sing_merge_semaphore``, and the
    work happens in a temporary directory that is removed afterwards.

    Raises:
        HTTPException: 500 if the ffmpeg binary cannot be executed; 502 if the
            re-encode fallback also fails (first 200 chars of stderr included).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, "input.mp4")
        output_path = os.path.join(tmpdir, "output.mp4")
        with open(input_path, "wb") as file_handle:
            file_handle.write(video_bytes)
        # Fast path: copy the video stream unchanged and drop audio.
        command = [
            _ffmpeg_bin(),
            "-y",
            "-loglevel",
            "error",
            "-i",
            input_path,
            "-map",
            "0:v:0",
            "-c:v",
            "copy",
            "-an",
            "-movflags",
            "+faststart",
            output_path,
        ]
        try:
            with _semaphore_guard(_sing_merge_semaphore):
                subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except FileNotFoundError as exc:
            raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
        except subprocess.CalledProcessError:
            # Stream copy was rejected: re-encode the video stream instead.
            fallback = [
                _ffmpeg_bin(),
                "-y",
                "-loglevel",
                "error",
                "-i",
                input_path,
                "-map",
                "0:v:0",
                "-c:v",
                "libx264",
                "-preset",
                "veryfast",
                "-crf",
                "23",
                "-pix_fmt",
                "yuv420p",
                "-an",
                "-movflags",
                "+faststart",
                output_path,
            ]
            try:
                with _semaphore_guard(_sing_merge_semaphore):
                    subprocess.run(fallback, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except FileNotFoundError as exc:
                raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
            except subprocess.CalledProcessError as exc:
                stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
                raise HTTPException(status_code=502, detail=f"ffmpeg 去除音轨失败: {stderr[:200]}") from exc
        # Read the result before the temporary directory is deleted.
        with open(output_path, "rb") as file_handle:
            return file_handle.read()
def _merge_audio_video(base_video_url: str, audio_url: str) -> bytes:
    """Download a base video and an audio track, mux them, and return the MP4 bytes.

    Both files are fetched into a temporary directory. The video is looped to
    cover the audio by _run_ffmpeg_merge, and the merged file is read back
    before the directory is removed.
    """
    with tempfile.TemporaryDirectory() as workdir:
        local_video = os.path.join(workdir, "base.mp4")
        local_audio = os.path.join(workdir, "audio.mp3")
        merged = os.path.join(workdir, "merged.mp4")
        for remote, local in ((base_video_url, local_video), (audio_url, local_audio)):
            _download_to_path(remote, local)
        _run_ffmpeg_merge(local_video, local_audio, merged)
        with open(merged, "rb") as merged_file:
            data = merged_file.read()
    return data
def _extract_audio_segment(
    input_path: str,
    start_sec: float,
    duration_sec: float,
    output_path: str,
):
    """Cut ``duration_sec`` seconds starting at ``start_sec`` from ``input_path``.

    Writes an MP3 (libmp3lame, 128 kbps, 44.1 kHz, stereo, video dropped) to
    ``output_path``. The ffmpeg run holds ``_sing_merge_semaphore``.

    Raises:
        HTTPException: 500 if the ffmpeg binary cannot be executed; 502 if
            ffmpeg exits with an error (first 200 chars of stderr included).
    """
    # -ss/-t come after -i, so ffmpeg applies them as output options
    # (decoding from the beginning and discarding up to start_sec).
    command = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel",
        "error",
        "-i",
        input_path,
        "-ss",
        f"{start_sec:.3f}",
        "-t",
        f"{duration_sec:.3f}",
        "-vn",
        "-acodec",
        "libmp3lame",
        "-b:a",
        "128k",
        "-ar",
        "44100",
        "-ac",
        "2",
        output_path,
    ]
    try:
        with _semaphore_guard(_sing_merge_semaphore):
            subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"音频分段失败: {stderr[:200]}") from exc
def _pad_audio_segment(
    input_path: str,
    pad_sec: float,
    target_duration_sec: float,
    output_path: str,
):
    """Append trailing silence to ``input_path`` and write an MP3 to ``output_path``.

    Adds up to ``pad_sec`` seconds of silence (``apad=pad_dur``) and caps the
    output at ``target_duration_sec`` seconds. Encodes with libmp3lame at
    128 kbps, 44.1 kHz, stereo. The ffmpeg run holds ``_sing_merge_semaphore``.

    NOTE(review): when ``pad_sec <= 0`` this returns immediately and does NOT
    create ``output_path``; callers must handle that case themselves.

    Raises:
        HTTPException: 500 if the ffmpeg binary cannot be executed; 502 if
            ffmpeg exits with an error (first 200 chars of stderr included).
    """
    if pad_sec <= 0:
        return
    command = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel",
        "error",
        "-i",
        input_path,
        "-af",
        f"apad=pad_dur={pad_sec:.3f}",
        "-t",
        f"{target_duration_sec:.3f}",
        "-acodec",
        "libmp3lame",
        "-b:a",
        "128k",
        "-ar",
        "44100",
        "-ac",
        "2",
        output_path,
    ]
    try:
        with _semaphore_guard(_sing_merge_semaphore):
            subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"音频补齐失败: {stderr[:200]}") from exc
def _trim_video_duration(input_path: str, target_duration_sec: float, output_path: str):
    """Write the first ``target_duration_sec`` seconds of ``input_path`` to ``output_path``.

    A lossless stream copy is tried first. If ffmpeg rejects it, the clip is
    re-encoded (H.264/yuv420p video, AAC audio). Each ffmpeg run holds
    ``_sing_merge_semaphore``.

    Raises:
        HTTPException: 500 if the ffmpeg binary cannot be executed (on either
            attempt); 502 if the re-encode fallback also fails.
    """
    duration_arg = f"{target_duration_sec:.3f}"
    command = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel",
        "error",
        "-i",
        input_path,
        "-t",
        duration_arg,
        "-c",
        "copy",
        "-movflags",
        "+faststart",
        output_path,
    ]
    try:
        with _semaphore_guard(_sing_merge_semaphore):
            subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError:
        # Stream copy failed; fall through to a full re-encode.
        pass
    fallback = [
        _ffmpeg_bin(),
        "-y",
        "-loglevel",
        "error",
        "-i",
        input_path,
        "-t",
        duration_arg,
        "-c:v",
        "libx264",
        "-preset",
        "veryfast",
        "-crf",
        "23",
        "-pix_fmt",
        "yuv420p",
        "-c:a",
        "aac",
        "-movflags",
        "+faststart",
        output_path,
    ]
    try:
        with _semaphore_guard(_sing_merge_semaphore):
            subprocess.run(fallback, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        # Same mapping as the first attempt (and the other ffmpeg helpers).
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"视频裁剪失败: {stderr[:200]}") from exc
def _trim_video_bytes(video_bytes: bytes, target_duration_sec: float) -> bytes:
if target_duration_sec <= 0:
return video_bytes
with tempfile.TemporaryDirectory() as tmpdir:
input_path = os.path.join(tmpdir, "input.mp4")
output_path = os.path.join(tmpdir, "trimmed.mp4")
with open(input_path, "wb") as file_handle:
file_handle.write(video_bytes)
_trim_video_duration(input_path, target_duration_sec, output_path)
with open(output_path, "rb") as file_handle:
return file_handle.read()
def _concat_video_files(video_paths: list[str], output_path: str):
    """Concatenate ``video_paths`` in order into ``output_path`` (ffmpeg concat demuxer).

    A stream copy is tried first. If that fails, the inputs are re-encoded
    (H.264/yuv420p video, AAC audio). Each ffmpeg run holds
    ``_sing_merge_semaphore``. The temporary list file is always removed.

    Raises:
        HTTPException: 500 if ``video_paths`` is empty or ffmpeg cannot be
            executed; 502 if both attempts fail (first 200 chars of stderr).
    """
    if not video_paths:
        raise HTTPException(status_code=500, detail="合成视频列表为空")
    list_path = None
    try:
        # Write the list as UTF-8 (the encoding ffmpeg expects for file names)
        # and escape single quotes as '\'' so paths containing "'" stay valid
        # inside the quoted "file '...'" directive.
        with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as list_file:
            list_path = list_file.name
            for path in video_paths:
                quoted_path = path.replace("'", "'\\''")
                list_file.write(f"file '{quoted_path}'\n")
        # -safe 0 allows absolute paths in the list.
        command = [
            _ffmpeg_bin(),
            "-y",
            "-loglevel",
            "error",
            "-f",
            "concat",
            "-safe",
            "0",
            "-i",
            list_path,
            "-c",
            "copy",
            "-movflags",
            "+faststart",
            output_path,
        ]
        try:
            with _semaphore_guard(_sing_merge_semaphore):
                subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError:
            # Stream copy failed (e.g. inputs with mismatched codec parameters): re-encode.
            fallback = [
                _ffmpeg_bin(),
                "-y",
                "-loglevel",
                "error",
                "-f",
                "concat",
                "-safe",
                "0",
                "-i",
                list_path,
                "-c:v",
                "libx264",
                "-preset",
                "veryfast",
                "-crf",
                "23",
                "-pix_fmt",
                "yuv420p",
                "-c:a",
                "aac",
                "-movflags",
                "+faststart",
                output_path,
            ]
            with _semaphore_guard(_sing_merge_semaphore):
                subprocess.run(fallback, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg 未安装或不可用") from exc
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode("utf-8", errors="ignore") if exc.stderr else ""
        raise HTTPException(status_code=502, detail=f"视频拼接失败: {stderr[:200]}") from exc
    finally:
        # Best-effort cleanup of the temporary list file.
        if list_path:
            try:
                os.remove(list_path)
            except OSError:
                pass
def _submit_sing_task(
    prompt: str,
    image_url: str,
    negative_prompt: Optional[str],
    audio_url: Optional[str],
    model: str,
    resolution: str,
    duration: int,
) -> str:
    """
    Submit an async image-to-video job to DashScope and return its task_id.

    Sends ``image_url`` with ``prompt`` (falling back to SING_BASE_PROMPT when
    empty) and ``negative_prompt``. ``audio_url`` is attached when given, and
    for the wan2.6 model ``parameters.audio`` is enabled.

    Raises:
        HTTPException: 500 when DASHSCOPE_API_KEY is not configured; 502 when
            the request fails, returns non-200, is not JSON, or has no task id.
    """
    if not settings.DASHSCOPE_API_KEY:
        raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
    base_prompt = prompt or SING_BASE_PROMPT
    input_obj = {
        "prompt": base_prompt,
        "negative_prompt": negative_prompt,
        "img_url": image_url,
    }
    if audio_url:
        input_obj["audio_url"] = audio_url
    parameters = {
        "resolution": resolution,
        "duration": duration,
        "prompt_extend": SING_BASE_PROMPT_EXTEND,
    }
    if model == SING_WAN26_MODEL:
        parameters["audio"] = True
    input_payload = {
        "model": model,
        "input": input_obj,
        "parameters": parameters,
    }
    headers = {
        "X-DashScope-Async": "enable",
        "Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}",
        "Content-Type": "application/json",
    }
    try:
        resp = requests.post(
            "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis",
            headers=headers,
            json=input_payload,
            timeout=10,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail="调用视频生成接口失败") from exc
    if resp.status_code != 200:
        msg = resp.text
        try:
            msg = resp.json().get("message") or msg
        except Exception:
            pass
        raise HTTPException(status_code=502, detail=f"视频任务提交失败: {msg}")
    try:
        data = resp.json()
    except Exception as exc:
        raise HTTPException(status_code=502, detail="视频任务返回解析失败") from exc
    # "output" may be missing or explicitly null; treat both as empty instead
    # of crashing with AttributeError on None.get(...).
    output = data.get("output") or {}
    task_id = output.get("task_id") or data.get("task_id") or output.get("id")
    if not task_id:
        raise HTTPException(status_code=502, detail="视频任务未返回 task_id")
    return str(task_id)
def _submit_emo_video(
    image_url: str,
    audio_url: str,
    face_bbox: list,
    ext_bbox: list,
    style_level: str,
) -> str:
    """Submit an async EMO talking/singing-video job to DashScope and return its task_id.

    ``face_bbox``/``ext_bbox`` are forwarded as given (presumably the values
    from a prior _emo_detect call — confirm against callers).

    Raises:
        HTTPException: 500 when DASHSCOPE_API_KEY is not configured; 502 when
            the request fails, returns non-200, is not JSON, or has no task id.
    """
    if not settings.DASHSCOPE_API_KEY:
        raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")
    input_obj = {
        "image_url": image_url,
        "audio_url": audio_url,
        "face_bbox": face_bbox,
        "ext_bbox": ext_bbox,
    }
    input_payload = {
        "model": EMO_MODEL,
        "input": input_obj,
        "parameters": {"style_level": style_level},
    }
    headers = {
        "X-DashScope-Async": "enable",
        "Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}",
        "Content-Type": "application/json",
    }
    logger.info(f"提交 EMO 视频生成任务model={EMO_MODEL}, style_level={style_level}")
    logger.debug(f"请求参数: {input_payload}")
    try:
        resp = requests.post(
            "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/video-synthesis",
            headers=headers,
            json=input_payload,
            timeout=15,
        )
    except Exception as exc:
        logger.error(f"调用 EMO API 失败: {exc}")
        raise HTTPException(status_code=502, detail="调用EMO视频生成失败") from exc
    logger.info(f"EMO API 返回状态码: {resp.status_code}")
    if resp.status_code != 200:
        msg = resp.text
        try:
            msg = resp.json().get("message") or msg
        except Exception:
            pass
        logger.error(f"EMO 任务提交失败: {msg}")
        raise HTTPException(status_code=502, detail=f"EMO视频任务提交失败: {msg}")
    try:
        data = resp.json()
        logger.info(f"EMO API 返回数据: {data}")
    except Exception as exc:
        logger.error(f"解析 EMO API 响应失败: {exc}")
        raise HTTPException(status_code=502, detail="EMO视频任务返回解析失败") from exc
    # "output" may be missing or explicitly null; treat both as empty.
    output = data.get("output") or {}
    task_id = output.get("task_id") or data.get("task_id") or output.get("id")
    if not task_id:
        logger.error(f"EMO API 未返回 task_id完整响应: {data}")
        raise HTTPException(status_code=502, detail="EMO视频任务未返回 task_id")
    logger.info(f"EMO 任务提交成功task_id={task_id}")
    return str(task_id)
def _update_task_status_in_db(dashscope_task_id: str, status: str, result_url: Optional[str] = None):
    """Best-effort sync of a DashScope task's status onto its GenerationTask row.

    The row is found by a LIKE match on ``"dashscope_task_id": "<id>"`` in its
    JSON payload. Recognised DashScope statuses are mapped onto
    ``GenerationTask.status``; unrecognised ones leave the current status
    alone. The raw DashScope status and an update timestamp are written into
    the payload, and ``result_url`` is stored when given. All errors are
    logged and swallowed so callers (e.g. pollers) are never interrupted.
    """
    # DashScope task_status -> GenerationTask.status
    status_mapping = {
        "PENDING": "pending",
        "RUNNING": "running",
        "SUCCEEDED": "succeeded",
        "FAILED": "failed",
    }
    try:
        with SessionLocal() as db:
            # NOTE(review): this matches the serialized JSON text, so it relies
            # on the payload being stored with '": "' separators — confirm.
            task = db.query(GenerationTask).filter(
                GenerationTask.payload.like(f'%"dashscope_task_id": "{dashscope_task_id}"%')
            ).first()
            if not task:
                return
            mapped_status = status_mapping.get(status)
            if mapped_status is not None:
                task.status = mapped_status
            # Otherwise (e.g. "UNKNOWN", "CANCELED", "") keep the current status
            # rather than regressing a running task back to "pending".
            if result_url:
                task.result_url = result_url
            # Build a NEW dict: mutating the loaded JSON value in place and
            # re-assigning the same object is not detected as a change by
            # SQLAlchemy's JSON type, so the payload update would be dropped.
            payload = dict(task.payload or {})
            payload["dashscope_status"] = status
            payload["last_updated"] = time.time()
            task.payload = payload
            task_pk = task.id
            new_status = task.status
            db.commit()
            logger.info(f"📊 任务 {task_pk} 状态已更新为: {new_status}")
    except Exception as e:
        logger.error(f"❌ 更新任务状态失败: {e}")
def _poll_video_url(task_id: str, timeout_seconds: int = 360) -> str:
    """Poll a DashScope async task until it finishes and return its video URL.

    Polls every 3 seconds until a deadline of ``max(60, timeout_seconds)``
    seconds. Transport, HTTP and JSON errors are retried, with every 10th one
    logged. Every 5th poll logs progress and mirrors the status into the
    database via _update_task_status_in_db.

    Raises:
        HTTPException: 502 if the task fails, is canceled, or succeeds without
            a video URL; 504 if the deadline passes first.
    """
    logger.info(f"⏳ 开始轮询 DashScope 任务: {task_id}")
    headers = {"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"}
    query_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
    deadline = time.time() + max(60, timeout_seconds)
    attempts = 0
    while time.time() < deadline:
        time.sleep(3)
        attempts += 1
        try:
            resp = requests.get(query_url, headers=headers, timeout=8)
        except Exception as e:
            if attempts % 10 == 0:
                logger.warning(f"⚠️ 轮询任务 {task_id}{attempts} 次请求失败: {e}")
            continue
        if resp.status_code != 200:
            if attempts % 10 == 0:
                logger.warning(f"⚠️ 轮询任务 {task_id}{attempts} 次返回状态码: {resp.status_code}")
            continue
        try:
            data = resp.json()
        except Exception as e:
            if attempts % 10 == 0:
                logger.warning(f"⚠️ 轮询任务 {task_id}{attempts} 次 JSON 解析失败: {e}")
            continue
        output = data.get("output") or {}
        status_str = str(
            output.get("task_status")
            or data.get("task_status")
            or data.get("status")
            or ""
        ).upper()
        # Every 5 polls (~15 s): log progress and mirror the status into the DB.
        if attempts % 5 == 0:
            logger.info(f"🔄 轮询任务 {task_id}{attempts} 次,状态: {status_str}")
            _update_task_status_in_db(task_id, status_str, None)
        if status_str == "SUCCEEDED":
            results = output.get("results")
            if not isinstance(results, dict):
                results = {}
            # `output` is already normalized to a dict, so no chained
            # data.get("output", {}) lookup (which breaks on a null output).
            url = (
                results.get("video_url")
                or output.get("video_url")
                or data.get("video_url")
            )
            if not url:
                logger.error(f"❌ 任务 {task_id} 成功但未返回 URL")
                raise HTTPException(status_code=502, detail="视频生成成功但未返回结果 URL")
            logger.info(f"✅ 任务 {task_id} 生成成功!")
            _update_task_status_in_db(task_id, "SUCCEEDED", url)
            return url
        # CANCELED is terminal as well; without it the loop would spin until
        # the deadline and report a misleading timeout.
        if status_str in ("FAILED", "CANCELED"):
            code = output.get("code") or data.get("code")
            msg = output.get("message") or data.get("message") or "生成失败"
            if code:
                msg = f"{code}: {msg}"
            logger.error(f"❌ 任务 {task_id} 生成失败: {msg}")
            _update_task_status_in_db(task_id, "FAILED", None)
            raise HTTPException(status_code=502, detail=f"视频生成失败: {msg}")
    logger.error(f"⏱️ 任务 {task_id} 轮询超时,共尝试 {attempts}")
    raise HTTPException(status_code=504, detail="视频生成超时,请稍后重试")
def _query_dashscope_task_status(task_id: str) -> tuple[str, Optional[str], Optional[str]]:
    """Fetch a DashScope async task once and summarize it as ``(status, video_url, error)``.

    * ``("SUCCEEDED", url, None)`` on success.
    * ``("FAILED", None, message)`` when the task failed or was canceled, or
      succeeded without returning a video URL.
    * ``(raw_status, None, None)`` for any other reported status (e.g. RUNNING).
    * ``("UNKNOWN", None, reason)`` when no API key is configured, or the
      request, HTTP status or JSON decoding fails (these are not raised).
    * ``("UNKNOWN", None, None)`` when no status is reported.
    """
    if not settings.DASHSCOPE_API_KEY:
        return "UNKNOWN", None, "未配置 DASHSCOPE_API_KEY"
    headers = {"Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}"}
    query_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
    try:
        resp = requests.get(query_url, headers=headers, timeout=8)
    except Exception:
        return "UNKNOWN", None, "请求失败"
    if resp.status_code != 200:
        return "UNKNOWN", None, resp.text
    try:
        data = resp.json()
    except Exception:
        return "UNKNOWN", None, "响应解析失败"
    output = data.get("output") or {}
    status_str = str(
        output.get("task_status")
        or data.get("task_status")
        or data.get("status")
        or ""
    ).upper()
    if status_str == "SUCCEEDED":
        results = output.get("results")
        if not isinstance(results, dict):
            results = {}
        # `output` is already a dict; avoid data.get("output", {}).get(...),
        # which raises AttributeError when "output" is null.
        url = (
            results.get("video_url")
            or output.get("video_url")
            or data.get("video_url")
        )
        if not url:
            return "FAILED", None, "视频生成成功但未返回结果 URL"
        return "SUCCEEDED", url, None
    # A canceled task will never produce a video; report it as a failure so
    # callers that only act on SUCCEEDED/FAILED stop waiting on it.
    if status_str in ("FAILED", "CANCELED"):
        code = output.get("code") or data.get("code")
        msg = output.get("message") or data.get("message") or "生成失败"
        if code:
            msg = f"{code}: {msg}"
        return "FAILED", None, msg
    if status_str:
        return status_str, None, None
    return "UNKNOWN", None, None
def _try_backfill_segment_video(segment_video_id: int, dashscope_task_id: Optional[str] = None) -> None:
    """
    Query DashScope for a running segment video's task and update the database.

    Nothing is returned (and no ORM object escapes a session) to avoid
    SQLAlchemy DetachedInstanceError.

    Only rows still marked "running" are touched:
    * SUCCEEDED: the video is downloaded, optionally trimmed, uploaded to OSS,
      and the row is marked "succeeded" with the OSS URL. Any error in that
      pipeline marks the row "failed" instead.
    * FAILED: the row is marked "failed" with DashScope's error message.
    * Any other status leaves the row unchanged.

    Args:
        segment_video_id: Primary key of the SongSegmentVideo row.
        dashscope_task_id: Task id to query; defaults to the id stored on the row.
    """
    # Copy everything needed later into plain values while the session is
    # open; the network and OSS work below runs outside this session.
    with SessionLocal() as db:
        segment_video = (
            db.query(SongSegmentVideo)
            .filter(SongSegmentVideo.id == segment_video_id)
            .first()
        )
        if not segment_video or segment_video.status != "running":
            return
        task_id = dashscope_task_id or segment_video.dashscope_task_id
        if not task_id:
            return
        segment = (
            db.query(SongSegment)
            .filter(SongSegment.id == segment_video.segment_id)
            .first()
        )
        # Fall back to the raw segment_id if the SongSegment row is missing.
        segment_index = segment.segment_index if segment else segment_video.segment_id
        info = {
            "task_id": task_id,
            "segment_index": segment_index,
            "lover_id": segment_video.lover_id,
            "song_id": segment_video.song_id,
            "image_hash": segment_video.image_hash,
            "style_level": segment_video.style_level,
            "duration_ms": segment.duration_ms if segment else None,
        }
    status, dash_video_url, error_msg = _query_dashscope_task_status(info["task_id"])
    if status == "SUCCEEDED" and dash_video_url:
        try:
            video_bytes = _download_binary(dash_video_url)
            min_len_ms = int((settings.EMO_MIN_SEGMENT_SECONDS or 0) * 1000)
            duration_ms = info.get("duration_ms") or 0
            # For segments shorter than the configured minimum, trim the
            # returned video back to the segment's own duration.
            # NOTE(review): presumably the segment audio was padded up to the
            # minimum before submission — confirm against the submit path.
            if min_len_ms > 0 and duration_ms > 0 and duration_ms < min_len_ms:
                video_bytes = _trim_video_bytes(video_bytes, duration_ms / 1000.0)
            # Deterministic object key per song / image / style / segment index.
            object_name = (
                f"lover/{info['lover_id']}/sing/segments/"
                f"{info['song_id']}_{info['image_hash']}_{info['style_level']}_{info['segment_index']}.mp4"
            )
            segment_video_url = _upload_to_oss(video_bytes, object_name)
        except Exception as exc:
            # Record the failure under a row lock, unless another worker has
            # already moved the row out of "running".
            with SessionLocal() as db:
                segment_video = (
                    db.query(SongSegmentVideo)
                    .filter(SongSegmentVideo.id == segment_video_id)
                    .with_for_update()
                    .first()
                )
                if segment_video and segment_video.status == "running":
                    segment_video.status = "failed"
                    # Truncated to 255 chars (presumably the column length).
                    segment_video.error_msg = (_extract_error_text(exc) or "生成失败")[:255]
                    segment_video.updated_at = datetime.utcnow()
                    db.add(segment_video)
                    db.commit()
            return
        # Re-check under a row lock so a row already finalized elsewhere is not overwritten.
        with SessionLocal() as db:
            segment_video = (
                db.query(SongSegmentVideo)
                .filter(SongSegmentVideo.id == segment_video_id)
                .with_for_update()
                .first()
            )
            if segment_video and segment_video.status == "running":
                segment_video.video_url = segment_video_url
                segment_video.status = "succeeded"
                segment_video.error_msg = None
                segment_video.updated_at = datetime.utcnow()
                db.add(segment_video)
                db.commit()
        return
    if status == "FAILED":
        with SessionLocal() as db:
            segment_video = (
                db.query(SongSegmentVideo)
                .filter(SongSegmentVideo.id == segment_video_id)
                .with_for_update()
                .first()
            )
            if segment_video and segment_video.status == "running":
                segment_video.status = "failed"
                segment_video.error_msg = str(error_msg or "生成失败")[:255]
                segment_video.updated_at = datetime.utcnow()
                db.add(segment_video)
                db.commit()
        return
def _wait_for_base_video(base_id: int, timeout: int) -> Optional[SingBaseVideo]:
    """Poll SingBaseVideo ``base_id`` every 3 s until it reaches a terminal status.

    Returns the row, loaded in a session that is closed on return, once its
    status is "succeeded" or "failed". Returns None if ``timeout`` seconds
    elapse first.
    """
    terminal_states = ("succeeded", "failed")
    deadline = time.time() + timeout
    while time.time() < deadline:
        with SessionLocal() as db:
            row = db.query(SingBaseVideo).filter(SingBaseVideo.id == base_id).first()
            if row and row.status in terminal_states:
                return row
        time.sleep(3)
    return None
def _wait_for_merge_video(merge_id: int, timeout: int) -> Optional[SingSongVideo]:
    """Poll SingSongVideo ``merge_id`` every 3 s until it reaches a terminal status.

    Returns the row, loaded in a session that is closed on return, once its
    status is "succeeded" or "failed". Returns None if ``timeout`` seconds
    elapse first.
    """
    terminal_states = ("succeeded", "failed")
    deadline = time.time() + timeout
    while time.time() < deadline:
        with SessionLocal() as db:
            row = db.query(SingSongVideo).filter(SingSongVideo.id == merge_id).first()
            if row and row.status in terminal_states:
                return row
        time.sleep(3)
    return None
def _wait_for_segment_video(segment_video_id: int, timeout: int) -> Optional[SongSegmentVideo]:
    """Wait for SongSegmentVideo ``segment_video_id`` to reach "succeeded" or "failed".

    Polls every 3 s. While the row is still pending and has a DashScope task
    id, _try_backfill_segment_video is called at most once every
    EMO_BACKFILL_MIN_INTERVAL_SECONDS, and the row is re-read immediately
    afterwards. On success a fresh, session-free SongSegmentVideo carrying
    only id/status/video_url/error_msg is returned, to avoid
    DetachedInstanceError. Returns None on timeout.
    """
    def _read_state() -> tuple[Optional[SongSegmentVideo], Optional[str]]:
        # Returns (terminal snapshot or None, dashscope task id or None).
        with SessionLocal() as db:
            row = (
                db.query(SongSegmentVideo)
                .filter(SongSegmentVideo.id == segment_video_id)
                .first()
            )
            if not row:
                return None, None
            if row.status in ("succeeded", "failed"):
                snapshot = SongSegmentVideo(
                    id=row.id,
                    status=row.status,
                    video_url=row.video_url,
                    error_msg=row.error_msg,
                )
                return snapshot, None
            return None, row.dashscope_task_id

    deadline = time.time() + timeout
    last_backfill_at = 0.0
    while time.time() < deadline:
        snapshot, dash_task_id = _read_state()
        if snapshot is not None:
            return snapshot
        now = time.time()
        if dash_task_id and now - last_backfill_at >= EMO_BACKFILL_MIN_INTERVAL_SECONDS:
            # Backfill writes through its own sessions; re-read the row afterwards.
            _try_backfill_segment_video(segment_video_id, dash_task_id)
            last_backfill_at = now
            snapshot, _ = _read_state()
            if snapshot is not None:
                return snapshot
        time.sleep(3)
    return None
def _backfill_running_segments(payload: dict):
    """Trigger a backfill for up to two stale "running" segment videos of a task payload.

    A SongSegmentVideo row qualifies when it matches the payload's song_id and
    image_hash, its ratio/style_level (payload values, or EMO_RATIO /
    EMO_STYLE_LEVEL as fallback), and EMO_MODEL. It must also be "running", have a
    DashScope task id, and have last been updated at least
    EMO_BACKFILL_STALE_SECONDS ago. The oldest ids go first. Nothing happens unless
    the payload has both song_id and image_hash.
    """
    song_id = payload.get("song_id")
    image_hash = payload.get("image_hash")
    if not (song_id and image_hash):
        return
    ratio = payload.get("ratio") or EMO_RATIO
    style_level = payload.get("style_level") or EMO_STYLE_LEVEL
    stale_before = datetime.utcnow() - timedelta(seconds=EMO_BACKFILL_STALE_SECONDS)
    with SessionLocal() as db:
        stale_rows = (
            db.query(SongSegmentVideo)
            .filter(
                SongSegmentVideo.song_id == song_id,
                SongSegmentVideo.image_hash == image_hash,
                SongSegmentVideo.ratio == ratio,
                SongSegmentVideo.style_level == style_level,
                SongSegmentVideo.model == EMO_MODEL,
                SongSegmentVideo.status == "running",
                SongSegmentVideo.dashscope_task_id.isnot(None),
                SongSegmentVideo.updated_at <= stale_before,
            )
            .order_by(SongSegmentVideo.id.asc())
            .limit(2)
            .all()
        )
        # Copy plain values out so the backfill calls don't depend on this session.
        targets = [(row.id, row.dashscope_task_id) for row in stale_rows]
    for segment_video_id, dash_task_id in targets:
        _try_backfill_segment_video(segment_video_id, dash_task_id)
def _download_binary(url: str) -> bytes:
    """Download ``url`` (30 s timeout) and return the response body.

    Raises:
        HTTPException(502): on any request error or a non-200 status code.
    """
    failure_detail = "唱歌视频下载失败"
    try:
        response = requests.get(url, timeout=30)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=failure_detail) from exc
    if response.status_code == 200:
        return response.content
    raise HTTPException(status_code=502, detail=failure_detail)
def _build_emo_segment_plan(duration: float) -> list[dict]:
if duration <= 0:
return []
max_len = float(EMO_SEGMENT_SECONDS)
segment_count = max(1, math.ceil(duration / max_len))
segments = []
for idx in range(segment_count):
start_sec = idx * max_len
remaining = max(0.0, duration - start_sec)
segment_duration = min(max_len, remaining)
if segment_duration <= 0:
break
segments.append({"start_sec": start_sec, "duration_sec": segment_duration})
return segments
def _fetch_complete_segments(song_id: int, audio_hash: str, expected_count: int) -> Optional[list[dict]]:
    """Return stored segment descriptors only when the song is fully segmented.

    The set counts as complete when exactly ``expected_count`` SongSegment rows exist
    for (song_id, audio_hash). Indices 1..expected_count must all be present, each
    with status "succeeded" and a non-empty audio_url. Otherwise, or when
    ``expected_count`` is not positive, ``None`` is returned.

    Each descriptor holds id, segment_index, audio_url, duration_ms and
    emo_duration_ms. emo_duration_ms is duration_ms raised to
    settings.EMO_MIN_SEGMENT_SECONDS when that setting is positive.
    """
    if expected_count <= 0:
        return None
    floor_ms = int((settings.EMO_MIN_SEGMENT_SECONDS or 0) * 1000)

    def _is_ready(row) -> bool:
        return bool(row) and row.status == "succeeded" and bool(row.audio_url)

    def _describe(row) -> dict:
        length_ms = row.duration_ms
        return {
            "id": row.id,
            "segment_index": row.segment_index,
            "audio_url": row.audio_url,
            "duration_ms": length_ms,
            "emo_duration_ms": max(length_ms, floor_ms) if floor_ms > 0 else length_ms,
        }

    with SessionLocal() as db:
        rows = (
            db.query(SongSegment)
            .filter(SongSegment.song_id == song_id, SongSegment.audio_hash == audio_hash)
            .order_by(SongSegment.segment_index.asc())
            .all()
        )
        if len(rows) != expected_count:
            return None
        by_index = {row.segment_index: row for row in rows}
        # Segment indices are 1-based; every one of them must be ready.
        if not all(_is_ready(by_index.get(i)) for i in range(1, expected_count + 1)):
            return None
        return [_describe(row) for row in rows]
def _ensure_song_segments(
    song_id: int,
    audio_url: str,
    audio_hash_hint: Optional[str],
    duration_sec_hint: Optional[int],
) -> tuple[list[dict], str, int]:
    """Make sure a song's audio is split into uploaded EMO segments and describe them.

    Fast path: when both hints are given, the expected segment count is derived from
    ``min(duration_sec_hint, SING_MAX_DURATION)``. If _fetch_complete_segments finds
    that full set already stored, it is returned without downloading anything.

    Slow path:
      1. Download the audio to a temp dir, hash it, probe its duration, and cap the
         duration at SING_MAX_DURATION.
      2. Store the hash and rounded-up duration on the SongLibrary row.
      3. For each planned window, reuse a succeeded SongSegment whose start/duration
         match. Otherwise (re)create the row as "running", delete its
         SongSegmentVideo rows, extract the clip with _extract_audio_segment, upload
         it to OSS and mark the row "succeeded".

    Args:
        song_id: SongLibrary id.
        audio_url: URL of the full song audio.
        audio_hash_hint: previously stored audio hash, if any.
        duration_sec_hint: previously stored duration in whole seconds, if any.

    Returns:
        (segments, audio_hash, duration_sec). Each segment dict has id,
        segment_index (1-based), audio_url, duration_ms and emo_duration_ms
        (duration_ms raised to settings.EMO_MIN_SEGMENT_SECONDS when that is set).
        NOTE(review): on the fast path the third value is
        min(duration_sec_hint, SING_MAX_DURATION), not a ceil'd probe result.

    Raises:
        HTTPException(502): the audio duration could not be probed.
        HTTPException(500): a segment row vanished between creation and update.
        Any extraction/upload error is re-raised after the segment is marked "failed".
    """
    if audio_hash_hint and duration_sec_hint:
        # Cap the duration at SING_MAX_DURATION (the original comment says 30 seconds).
        limited_duration = min(duration_sec_hint, SING_MAX_DURATION)
        expected_count = max(1, math.ceil(limited_duration / EMO_SEGMENT_SECONDS))
        existing = _fetch_complete_segments(song_id, audio_hash_hint, expected_count)
        if existing:
            return existing, audio_hash_hint, limited_duration
    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, "song_audio")
        _download_to_path(audio_url, input_path)
        # Hash of the downloaded file; keys SongSegment rows and segment object names.
        audio_hash = _hash_file(input_path)
        duration = _probe_media_duration(input_path)
        if not duration:
            raise HTTPException(status_code=502, detail="音频时长获取失败")
        # Cap the processed audio at SING_MAX_DURATION seconds (original comment: 30 s).
        duration = min(duration, float(SING_MAX_DURATION))
        duration_sec = int(math.ceil(duration))
        segment_plan = _build_emo_segment_plan(duration)
        # NOTE(review): expected_count is not read again on this path.
        expected_count = len(segment_plan) or 1
        # Persist hash/duration so later calls can take the fast path above.
        with SessionLocal() as db:
            song = db.query(SongLibrary).filter(SongLibrary.id == song_id).with_for_update().first()
            if song:
                song.audio_hash = audio_hash
                song.duration_sec = duration_sec
                db.add(song)
                db.commit()
        existing_segments = {}
        with SessionLocal() as db:
            segments = (
                db.query(SongSegment)
                .filter(SongSegment.song_id == song_id, SongSegment.audio_hash == audio_hash)
                .all()
            )
            # segment_index -> row; rows are read after this session closes (loaded columns only).
            existing_segments = {segment.segment_index: segment for segment in segments}
        output_segments: list[dict] = []
        min_len_ms = int((settings.EMO_MIN_SEGMENT_SECONDS or 0) * 1000)
        for idx, plan in enumerate(segment_plan, start=1):
            start_sec = plan["start_sec"]
            segment_duration = plan["duration_sec"]
            if segment_duration <= 0:
                continue
            segment = existing_segments.get(idx)
            if segment and segment.status == "succeeded" and segment.audio_url:
                expected_start_ms = int(start_sec * 1000)
                expected_duration_ms = int(segment_duration * 1000)
                # Reuse only if the stored window matches the current plan exactly.
                if segment.start_ms == expected_start_ms and segment.duration_ms == expected_duration_ms:
                    emo_duration_ms = segment.duration_ms
                    if min_len_ms > 0 and segment.duration_ms < min_len_ms:
                        emo_duration_ms = min_len_ms
                    output_segments.append(
                        {
                            "id": segment.id,
                            "segment_index": segment.segment_index,
                            "audio_url": segment.audio_url,
                            "duration_ms": segment.duration_ms,
                            "emo_duration_ms": emo_duration_ms,
                        }
                    )
                    continue
            start_ms = int(start_sec * 1000)
            duration_ms = int(segment_duration * 1000)
            output_path = os.path.join(tmpdir, f"segment_{idx}.mp3")
            # Claim the segment as "running" (creating it if needed) before extracting.
            with SessionLocal() as db:
                segment = (
                    db.query(SongSegment)
                    .filter(
                        SongSegment.song_id == song_id,
                        SongSegment.audio_hash == audio_hash,
                        SongSegment.segment_index == idx,
                    )
                    .with_for_update()
                    .first()
                )
                if not segment:
                    segment = SongSegment(
                        song_id=song_id,
                        audio_hash=audio_hash,
                        segment_index=idx,
                        start_ms=start_ms,
                        duration_ms=duration_ms,
                        audio_url="",
                        status="running",
                    )
                    db.add(segment)
                    db.flush()
                else:
                    segment.status = "running"
                    segment.error_msg = None
                    segment.start_ms = start_ms
                    segment.duration_ms = duration_ms
                    db.add(segment)
                    # The audio is being regenerated, so videos rendered from the old clip are dropped.
                    db.query(SongSegmentVideo).filter(SongSegmentVideo.segment_id == segment.id).delete(
                        synchronize_session=False
                    )
                db.commit()
            try:
                _extract_audio_segment(input_path, start_sec, segment_duration, output_path)
                with open(output_path, "rb") as file_handle:
                    segment_bytes = file_handle.read()
                object_name = f"song/{song_id}/segments/{audio_hash}_{idx}.mp3"
                segment_url = _upload_to_oss(segment_bytes, object_name)
                audio_size = os.path.getsize(output_path)
                with SessionLocal() as db:
                    segment = (
                        db.query(SongSegment)
                        .filter(
                            SongSegment.song_id == song_id,
                            SongSegment.audio_hash == audio_hash,
                            SongSegment.segment_index == idx,
                        )
                        .with_for_update()
                        .first()
                    )
                    if not segment:
                        raise HTTPException(status_code=500, detail="分段记录丢失")
                    segment.audio_url = segment_url
                    segment.audio_size = audio_size
                    segment.status = "succeeded"
                    segment.error_msg = None
                    db.add(segment)
                    db.flush()
                    segment_id = segment.id
                    db.commit()
                output_segments.append(
                    {
                        "id": segment_id,
                        "segment_index": idx,
                        "audio_url": segment_url,
                        "duration_ms": duration_ms,
                        "emo_duration_ms": max(duration_ms, min_len_ms) if min_len_ms > 0 else duration_ms,
                    }
                )
            except Exception as exc:
                # Record the failure on the segment row, then let the caller handle the error.
                with SessionLocal() as db:
                    segment = (
                        db.query(SongSegment)
                        .filter(
                            SongSegment.song_id == song_id,
                            SongSegment.audio_hash == audio_hash,
                            SongSegment.segment_index == idx,
                        )
                        .with_for_update()
                        .first()
                    )
                    if segment:
                        segment.status = "failed"
                        segment.error_msg = (str(exc) or "生成失败")[:255]
                        db.add(segment)
                        db.commit()
                raise
        output_segments.sort(key=lambda item: item["segment_index"])
        return output_segments, audio_hash, duration_sec
def _should_enqueue_task(task_id: int) -> bool:
    """Rate-limit re-enqueueing of one sing task id.

    Returns False when the same id was enqueued less than
    SING_REQUEUE_COOLDOWN_SECONDS ago. Otherwise it records "now" for the id and
    returns True. When more than 2000 ids are tracked, entries older than an hour
    are pruned. All bookkeeping happens under ``_sing_enqueue_lock``.
    """
    now = time.time()
    with _sing_enqueue_lock:
        previous = _sing_last_enqueue_at.get(task_id)
        in_cooldown = bool(previous) and now - previous < SING_REQUEUE_COOLDOWN_SECONDS
        if in_cooldown:
            return False
        _sing_last_enqueue_at[task_id] = now
        if len(_sing_last_enqueue_at) > 2000:
            horizon = now - 3600
            stale_keys = [key for key, stamp in _sing_last_enqueue_at.items() if stamp < horizon]
            for key in stale_keys:
                _sing_last_enqueue_at.pop(key, None)
        return True
def _enqueue_sing_task(task_id: int):
    """Queue ``_process_sing_task(task_id)`` on ``sing_task_queue`` under key ``sing:<task_id>``.

    Returns whatever ``sing_task_queue.enqueue_unique`` returns. Its meaning is
    defined by the task queue; presumably it reports whether the key was newly
    queued (verify in task_queue). Nothing is logged here, on success or failure.
    """
    dedupe_key = f"sing:{task_id}"
    return sing_task_queue.enqueue_unique(dedupe_key, _process_sing_task, task_id)
def _next_seq(db: Session, session_id: int) -> int:
    """Return the next ChatMessage ``seq`` for ``session_id``: highest existing seq + 1, or 1.

    NOTE(review): read-then-insert, so concurrent writers to the same session can
    compute the same value unless the caller serializes them.
    """
    newest = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        .order_by(ChatMessage.seq.desc())
        .first()
    )
    current = newest.seq if (newest and newest.seq) else 0
    return current + 1
def _mark_task_failed(task_id: int, msg: str):
    """Mark a GenerationTask as failed and turn its placeholder lover message into an error.

    Uses its own session, so ORM objects from a caller's (possibly closed) session are
    never touched. Missing tasks, messages or chat sessions are skipped.

    Args:
        task_id: GenerationTask primary key.
        msg: failure reason. The lover message shows it with a failure prefix; the
            task stores it truncated to 255 characters.
    """
    with SessionLocal() as db:

        def _locked(model, pk):
            # Row-lock the record for the rest of this transaction.
            return db.query(model).filter(model.id == pk).with_for_update().first()

        task = _locked(GenerationTask, task_id)
        if not task:
            return
        payload = task.payload or {}
        lover_msg_id = payload.get("lover_message_id")
        session_id = payload.get("session_id")

        placeholder = _locked(ChatMessage, lover_msg_id) if lover_msg_id else None
        if placeholder:
            placeholder.content = f"唱歌视频生成失败:{msg}"
            placeholder.extra = {
                **(placeholder.extra or {}),
                "generation_status": "failed",
                "error_msg": msg,
                "generation_task_id": task_id,
            }
            placeholder.tts_status = placeholder.tts_status or "pending"
            db.add(placeholder)

        chat_session = _locked(ChatSession, session_id) if session_id else None
        if chat_session:
            chat_session.last_message_at = datetime.utcnow()
            db.add(chat_session)

        task.status = "failed"
        task.error_msg = msg[:255]
        task.updated_at = datetime.utcnow()
        db.add(task)
        db.commit()
def _mark_task_content_blocked(task_id: int, msg: str):
    """Fail a GenerationTask because content safety blocked it.

    Unlike a generic failure, the lover placeholder message is set to ``msg`` verbatim
    (no failure prefix). Both its ``extra`` and the task payload are flagged with
    ``content_safety_blocked=True``. Runs in its own session; missing rows are skipped.

    Args:
        task_id: GenerationTask primary key.
        msg: text shown to the user; the task stores it truncated to 255 characters.
    """
    with SessionLocal() as db:
        task = (
            db.query(GenerationTask)
            .filter(GenerationTask.id == task_id)
            .with_for_update()
            .first()
        )
        if not task:
            return
        payload = task.payload or {}
        placeholder_id = payload.get("lover_message_id")
        chat_session_id = payload.get("session_id")

        placeholder = None
        if placeholder_id:
            placeholder = (
                db.query(ChatMessage)
                .filter(ChatMessage.id == placeholder_id)
                .with_for_update()
                .first()
            )
        if placeholder:
            placeholder.content = msg
            placeholder.extra = {
                **(placeholder.extra or {}),
                "generation_status": "failed",
                "error_msg": msg,
                "generation_task_id": task_id,
                "content_safety_blocked": True,
            }
            placeholder.tts_status = placeholder.tts_status or "pending"
            db.add(placeholder)

        chat_session = None
        if chat_session_id:
            chat_session = (
                db.query(ChatSession)
                .filter(ChatSession.id == chat_session_id)
                .with_for_update()
                .first()
            )
        if chat_session:
            chat_session.last_message_at = datetime.utcnow()
            db.add(chat_session)

        task.status = "failed"
        task.error_msg = msg[:255]
        task.payload = {**payload, "content_safety_blocked": True}
        task.updated_at = datetime.utcnow()
        db.add(task)
        db.commit()
def _sanitize_resolution(resolution: Optional[str]) -> str:
"""wan2.2-i2v-flash 支持 480P/720P/1080P非法值回落 480P。"""
if not resolution:
return "480P"
upper = str(resolution).upper()
if upper in ("480P", "720P", "1080P"):
return upper
return "480P"
class SongOut(BaseModel):
    # One entry of the GET /sing/songs song list (see SongListResponse).
    # from_attributes=True lets pydantic validate SongLibrary ORM rows directly.
    id: int
    title: str
    artist: Optional[str] = None
    gender: str
    audio_url: str
    model_config = ConfigDict(from_attributes=True)
class SongListResponse(BaseModel):
    # Response data of GET /sing/songs: the songs available for the lover's gender.
    songs: List[SongOut]
class SingGenerateIn(BaseModel):
    # Request body of POST /sing/generate; song_id refers to nf_song_library.id.
    song_id: int = Field(..., description="歌曲IDnf_song_library.id")
class SingGenerateFromLibraryIn(BaseModel):
    """从音乐库生成唱歌视频请求"""
    # English: request body for generating a sing video from a music-library entry
    # (music_id -> nf_music_library.id). Presumably consumed by an endpoint outside
    # this view; verify.
    music_id: int = Field(..., description="音乐库IDnf_music_library.id")
class SingTaskStatusOut(BaseModel):
    # Status snapshot of a sing GenerationTask, returned by GET /sing/current.
    # Values missing from the task payload are reported as 0 / "" rather than null.
    generation_task_id: int
    status: str = Field(..., description="pending|running|succeeded|failed")
    # DashScope task id stored in the task payload, "" if none.
    dashscope_task_id: str = ""
    # Final merged video URL (task result_url or payload merged_video_url), "" if none yet.
    video_url: str = ""
    session_id: int = 0
    user_message_id: int = 0
    lover_message_id: int = 0
    error_msg: Optional[str] = None
@router.get("/current", response_model=ApiResponse[Optional[SingTaskStatusOut]])
def get_current_sing_task(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    # Report the caller's newest in-flight sing task: a "video" GenerationTask in
    # pending/running state whose payload carries a song_id. Returns data=None
    # when nothing is in flight.
    latest = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
            GenerationTask.status.in_(["pending", "running"]),
            GenerationTask.payload["song_id"].as_integer().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .first()
    )
    if not latest:
        return success_response(None, msg="暂无进行中的任务")
    info = latest.payload or {}

    def _int_field(key: str) -> int:
        # Missing/empty payload ids are reported as 0.
        return int(info.get(key) or 0)

    snapshot = SingTaskStatusOut(
        generation_task_id=latest.id,
        status=latest.status,
        dashscope_task_id=str(info.get("dashscope_task_id") or ""),
        video_url=latest.result_url or info.get("merged_video_url") or "",
        session_id=_int_field("session_id"),
        user_message_id=_int_field("user_message_id"),
        lover_message_id=_int_field("lover_message_id"),
        error_msg=latest.error_msg,
    )
    return success_response(snapshot, msg="获取成功")
@router.get("/songs", response_model=ApiResponse[SongListResponse])
def list_songs_for_lover(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """
    返回当前用户恋人性别可用的歌曲列表。
    """
    # English: list enabled, non-deleted songs matching the lover's gender, ordered
    # by weigh desc then id desc. 404 when there is no lover or no song; 400 when
    # the lover's gender is neither "male" nor "female".
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")
    gender = lover.gender
    if gender not in ("male", "female"):
        raise HTTPException(status_code=400, detail="恋人性别异常,请重新选择性别")
    available = (
        db.query(SongLibrary)
        .filter(
            SongLibrary.gender == gender,
            SongLibrary.status.is_(True),
            SongLibrary.deletetime.is_(None),
        )
        .order_by(SongLibrary.weigh.desc(), SongLibrary.id.desc())
        .all()
    )
    if not available:
        raise HTTPException(status_code=404, detail="暂无可用歌曲,请稍后重试")
    return success_response(SongListResponse(songs=available), msg="歌曲列表获取成功")
@router.get("/history", response_model=ApiResponse[List[dict]])
def get_sing_history(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    """
    获取用户的唱歌视频历史记录
    """
    # English: return the caller's successfully generated sing videos, newest first.
    # Primary source: nf_sing_song_video rows. If a page comes up short it is topped
    # up from succeeded sing GenerationTasks (results that never reached
    # nf_sing_song_video), de-duplicated by CDN URL within the page.
    # Each item: {"id", "song_id", "song_title", "video_url", "created_at"}.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")

    # Sanitize paging. page < 1 would produce a negative OFFSET (a database error),
    # size <= 0 an empty/invalid LIMIT, and an unbounded size an unbounded query.
    page = max(page, 1)
    size = min(max(size, 1), 100)
    offset = (page - 1) * size

    def _titles_by_song_id(song_ids) -> dict:
        """Load SongLibrary titles for the given ids with a single query (no N+1)."""
        wanted = {song_id for song_id in song_ids if song_id}
        if not wanted:
            return {}
        rows = (
            db.query(SongLibrary.id, SongLibrary.title)
            .filter(SongLibrary.id.in_(wanted))
            .all()
        )
        return {row_id: title for row_id, title in rows}

    # Primary source: succeeded merged videos (nf_sing_song_video).
    videos = (
        db.query(SingSongVideo)
        .filter(
            SingSongVideo.user_id == user.id,
            SingSongVideo.lover_id == lover.id,
            SingSongVideo.status == "succeeded",
            SingSongVideo.merged_video_url.isnot(None),
        )
        .order_by(SingSongVideo.id.desc())
        .offset(offset)
        .limit(size)
        .all()
    )
    primary_titles = _titles_by_song_id(video.song_id for video in videos)
    result: list[dict] = []
    seen_urls: set[str] = set()
    for video in videos:
        url = _cdnize(video.merged_video_url)
        if url:
            seen_urls.add(url)
        result.append(
            {
                "id": video.id,
                "song_id": video.song_id,
                # A song row with an empty title keeps that value; a missing row falls back.
                "song_title": primary_titles.get(video.song_id, "未知歌曲"),
                "video_url": url,
                "created_at": video.created_at.isoformat() if video.created_at else None,
            }
        )

    # Fallback: the task succeeded but nf_sing_song_video was never written; use nf_generation_tasks.
    if len(result) < size:
        remaining = size - len(result)
        # NOTE(review): this reuses the primary query's offset, so the two sources are
        # paginated independently and an item may reappear on another page; duplicates
        # are only removed within the current page.
        fallback_tasks = (
            db.query(GenerationTask)
            .filter(
                GenerationTask.user_id == user.id,
                GenerationTask.lover_id == lover.id,
                GenerationTask.task_type == "video",
                GenerationTask.status == "succeeded",
                GenerationTask.payload["song_id"].as_integer().isnot(None),
                GenerationTask.payload["merged_video_url"].as_string().isnot(None),
            )
            .order_by(GenerationTask.id.desc())
            .offset(offset)
            .limit(size * 2)
            .all()
        )
        fallback_titles = _titles_by_song_id(
            (task.payload or {}).get("song_id") for task in fallback_tasks
        )
        for task in fallback_tasks:
            payload = task.payload or {}
            song_id = payload.get("song_id")
            merged_video_url = payload.get("merged_video_url") or task.result_url
            url = _cdnize(merged_video_url) if merged_video_url else ""
            if not url or url in seen_urls:
                continue
            song_key = int(song_id) if song_id else 0
            # Library title wins when present; otherwise the title captured in the payload.
            song_title = fallback_titles.get(song_key) or payload.get("song_title") or "未知歌曲"
            result.append(
                {
                    "id": int(task.id),
                    "song_id": song_key,
                    "song_title": song_title,
                    "video_url": url,
                    "created_at": task.created_at.isoformat() if task.created_at else None,
                }
            )
            seen_urls.add(url)
            remaining -= 1
            if remaining <= 0:
                break
    return success_response(result, msg="获取成功")
@router.post("/retry/{task_id}", response_model=ApiResponse[dict])
def retry_sing_task(
    task_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """手动重试:用于 DashScope 端已成功但本地下载/上传失败导致任务失败的情况。"""
    # English: manual retry for a caller-owned sing task.
    # - With a payload dashscope_task_id: flag manual_retry and schedule
    #   _retry_finalize_sing_task as a FastAPI background task.
    # - Without one (sing tasks usually keep DashScope ids per segment instead):
    #   reset the task to "pending" and re-enqueue the full pipeline, which reuses
    #   already-succeeded segments.
    task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.id == task_id,
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["song_id"].as_integer().isnot(None),
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    previous_payload = task.payload or {}
    has_dashscope_id = bool(previous_payload.get("dashscope_task_id"))
    if not has_dashscope_id:
        task.status = "pending"
        task.error_msg = None
    # The manual_retry flag is recorded in both cases.
    task.payload = {**previous_payload, "manual_retry": True}
    task.updated_at = datetime.utcnow()
    db.add(task)
    db.commit()
    task_pk = int(task.id)
    if has_dashscope_id:
        background_tasks.add_task(_retry_finalize_sing_task, task_pk)
    else:
        _enqueue_sing_task(task_pk)
    return success_response({"task_id": task_pk}, msg="已触发重新下载")
@router.get("/history/all", response_model=ApiResponse[List[dict]])
def get_sing_history_all(
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
    page: int = 1,
    size: int = 20,
):
    # List every sing GenerationTask of the caller (any status), newest first.
    # Each item: {"id", "song_id", "song_title", "status", "video_url",
    # "error_msg", "created_at"}; video_url is "" until a merged video exists.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人未找到")
    # Sanitize paging. page < 1 would produce a negative OFFSET (a database error),
    # size <= 0 an empty/invalid LIMIT, and an unbounded size an unbounded query.
    page = max(page, 1)
    size = min(max(size, 1), 100)
    offset = (page - 1) * size
    tasks = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.lover_id == lover.id,
            GenerationTask.task_type == "video",
            GenerationTask.payload["song_id"].as_integer().isnot(None),
        )
        .order_by(GenerationTask.id.desc())
        .offset(offset)
        .limit(size)
        .all()
    )
    result: list[dict] = []
    for task in tasks:
        payload = task.payload or {}
        song_id = payload.get("song_id")
        merged_video_url = payload.get("merged_video_url") or task.result_url
        # Pending/failed tasks have no URL yet; don't hand None to _cdnize.
        video_url = _cdnize(merged_video_url) if merged_video_url else ""
        result.append(
            {
                "id": int(task.id),
                "song_id": int(song_id) if song_id else 0,
                "song_title": payload.get("song_title") or "未知歌曲",
                "status": task.status,
                "video_url": video_url or "",
                "error_msg": task.error_msg,
                "created_at": task.created_at.isoformat() if task.created_at else None,
            }
        )
    return success_response(result, msg="获取成功")
def _get_or_create_session(db: Session, user: AuthedUser, lover: Lover, session_id: Optional[int]) -> ChatSession:
    """Resolve the ChatSession a sing result should be posted into.

    - With ``session_id``: return that session if it belongs to ``user``, else raise 404.
    - Otherwise: lock the user's "active" sessions with this lover and keep the newest
      (by created_at). Any other active ones are archived.
    - If none exist, create a new active session and flush it (no commit) so it has an id.
    """
    if session_id:
        requested = (
            db.query(ChatSession)
            .filter(ChatSession.id == session_id, ChatSession.user_id == user.id)
            .first()
        )
        if not requested:
            raise HTTPException(status_code=404, detail="会话不存在")
        return requested

    active = (
        db.query(ChatSession)
        .filter(ChatSession.user_id == user.id, ChatSession.lover_id == lover.id, ChatSession.status == "active")
        .order_by(ChatSession.created_at.desc())
        .with_for_update()
        .all()
    )
    if active:
        keeper, *duplicates = active
        for duplicate in duplicates:
            if duplicate.status == "active":
                duplicate.status = "archived"
                db.add(duplicate)
        return keeper

    created_at = datetime.utcnow()
    fresh = ChatSession(
        user_id=user.id,
        lover_id=lover.id,
        model=settings.LLM_MODEL or "qwen-flash",
        status="active",
        last_message_at=created_at,
        created_at=created_at,
        updated_at=created_at,
        inner_voice_enabled=False,
    )
    db.add(fresh)
    db.flush()
    return fresh
def _process_sing_task(task_id: int):
    """Background worker for one sing-video GenerationTask.

    Phase 1 runs in a single session that row-locks the task and the user. It
    validates the task, lover, user quota and song, resolves the image/audio URLs
    and EMO face/ext bounding boxes, then marks the task "running".

    Phase 2:
      - Split the song audio into segments (_ensure_song_segments).
      - Reuse a cached SingSongVideo for the same lover/song/audio/image/ratio/style
        if one exists. Otherwise render each segment with EMO, reusing succeeded or
        in-flight SongSegmentVideo rows, and concatenate the clips into one video.
      - In one transaction, upsert the SingSongVideo row, create or update the
        user/lover chat messages, deduct one video generation (once per task) and
        store the task result.

    If content safety blocks every segment, the task is marked with
    _mark_task_content_blocked. If it blocks only later segments, the clips
    rendered so far are merged and the result is flagged. Other failures
    (Exception subclasses) are recorded via _mark_task_failed instead of being
    propagated.

    Args:
        task_id: primary key of the GenerationTask to process.
    """

    def _close_quietly(session) -> None:
        # Best-effort close. Closing also rolls back the open transaction and thereby
        # releases any SELECT ... FOR UPDATE row locks the session still holds.
        if session is None:
            return
        try:
            session.close()
        except Exception:
            logger.debug(f"任务 {task_id}: 关闭数据库会话失败", exc_info=True)

    logger.info(f"开始处理唱歌任务 {task_id}")
    song_title: str = ""
    image_url: str = ""
    audio_url: str = ""
    session_id: Optional[int] = None
    user_message_id: Optional[int] = None
    lover_message_id: Optional[int] = None
    user_id: Optional[int] = None
    lover_id: Optional[int] = None
    song_id: Optional[int] = None
    image_hash: str = ""
    merged_video_url: str = ""
    merge_id: Optional[int] = None
    ratio: str = EMO_RATIO
    style_level: str = EMO_STYLE_LEVEL
    face_bbox: Optional[list] = None
    ext_bbox: Optional[list] = None
    audio_hash_hint: Optional[str] = None
    duration_sec_hint: Optional[int] = None

    # ---- Phase 1: validate and resolve inputs inside one row-locking session ----
    db = None
    try:
        logger.info(f"任务 {task_id}: 开始数据库查询")
        db = SessionLocal()
        task = (
            db.query(GenerationTask)
            .filter(GenerationTask.id == task_id)
            .with_for_update()
            .first()
        )
        logger.info(f"任务 {task_id}: 查询到任务,状态={task.status if task else 'None'}")
        if not task or task.status in ("succeeded", "failed"):
            logger.warning(f"任务 {task_id}: 任务不存在或已完成/失败,退出处理")
            db.rollback()
            return
        logger.info(f"任务 {task_id}: 开始提取 payload 数据")
        user_id = task.user_id
        lover_id = task.lover_id
        payload = task.payload or {}
        song_id = payload.get("song_id")
        song_title = payload.get("song_title") or ""
        image_url = payload.get("image_url") or ""
        audio_url = payload.get("audio_url") or ""
        session_id = payload.get("session_id")
        user_message_id = payload.get("user_message_id")
        lover_message_id = payload.get("lover_message_id")
        ratio = payload.get("ratio") or EMO_RATIO
        style_level = payload.get("style_level") or EMO_STYLE_LEVEL
        face_bbox = payload.get("face_bbox")
        ext_bbox = payload.get("ext_bbox")
        image_hash = payload.get("image_hash") or _hash_text(image_url or "")
        lover = db.query(Lover).filter(Lover.id == lover_id).first()
        user_row = db.query(User).filter(User.id == user_id).with_for_update().first()
        song = None
        if song_id:
            song = (
                db.query(SongLibrary)
                .filter(
                    SongLibrary.id == song_id,
                    SongLibrary.status.is_(True),
                    SongLibrary.deletetime.is_(None),
                )
                .first()
            )
        if not lover:
            raise HTTPException(status_code=404, detail="恋人不存在,请重新创建")
        if not lover.image_url:
            raise HTTPException(status_code=400, detail="请先生成并确认恋人形象")
        if not user_row:
            raise HTTPException(status_code=404, detail="用户不存在")
        if not song:
            raise HTTPException(status_code=404, detail="歌曲未找到或已下架")
        if song.gender != lover.gender:
            raise HTTPException(status_code=400, detail="歌曲版本性别与恋人不匹配")
        song_id = song.id
        if (user_row.video_gen_remaining or 0) <= 0:
            raise HTTPException(status_code=400, detail="视频生成次数不足")
        image_url = _cdnize(image_url or lover.image_url)
        audio_url = _cdnize(audio_url or song.audio_url)
        if not audio_url:
            raise HTTPException(status_code=400, detail="音频地址不可用")
        if not image_url:
            raise HTTPException(status_code=400, detail="形象地址不可用")
        if not image_hash:
            image_hash = _hash_text(image_url)
        if not song_title:
            song_title = song.title or "点播歌曲"
        audio_hash_hint = song.audio_hash
        duration_sec_hint = song.duration_sec
        if not face_bbox or not ext_bbox:
            # Face/dynamic regions are not in the payload yet: run (or reuse) EMO detection.
            detect = _ensure_emo_detect_cache(db, lover_id, image_url, image_hash, ratio)
            if not detect.check_pass:
                raise HTTPException(status_code=400, detail="恋人形象未通过EMO检测")
            face_bbox = detect.face_bbox
            ext_bbox = detect.ext_bbox
            if not face_bbox or not ext_bbox:
                raise HTTPException(status_code=502, detail="EMO检测返回缺少人脸或动态区域")
        task.status = "running"
        task.updated_at = datetime.utcnow()
        task.payload = {
            **payload,
            "song_id": song.id,
            "song_title": song_title,
            "image_url": image_url,
            "audio_url": audio_url,
            "image_hash": image_hash,
            "ratio": ratio,
            "style_level": style_level,
            "face_bbox": face_bbox,
            "ext_bbox": ext_bbox,
        }
        db.add(task)
        db.commit()
    except HTTPException as exc:
        # Release the FOR UPDATE locks on the task/user rows BEFORE marking the task
        # failed. _mark_task_failed locks the same task row from a second session and
        # would otherwise block on our own lock until the lock-wait timeout. That error
        # was then swallowed, leaving the task stuck in "pending".
        _close_quietly(db)
        db = None
        reason = str(exc.detail) if hasattr(exc, "detail") else str(exc)
        logger.warning(f"任务 {task_id} 预检查失败: {reason}")
        try:
            _mark_task_failed(task_id, reason)
        except Exception as e2:
            logger.exception(f"标记任务 {task_id} 失败时出错: {e2}")
        return
    except Exception as exc:
        _close_quietly(db)  # same lock-release reasoning as above
        db = None
        logger.exception(f"任务 {task_id} 预处理失败 (Exception): {exc}")
        try:
            _mark_task_failed(task_id, str(exc)[:255])
        except Exception as e2:
            logger.exception(f"标记任务 {task_id} 失败时出错: {e2}")
        return
    finally:
        _close_quietly(db)

    # ---- Phase 2: segment, render, merge and publish (short-lived sessions only) ----
    try:
        logger.info(f"任务 {task_id}: 开始音频分段处理")
        segments, audio_hash, duration_sec = _ensure_song_segments(
            song_id,
            audio_url,
            audio_hash_hint,
            duration_sec_hint,
        )
        logger.info(f"任务 {task_id}: 音频分段完成,共 {len(segments)} 段,时长 {duration_sec}")
        with SessionLocal() as db:
            task_row = (
                db.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if task_row:
                task_row.payload = {
                    **(task_row.payload or {}),
                    "audio_hash": audio_hash,
                    "duration_sec": duration_sec,
                    "segment_count": len(segments),
                }
                task_row.updated_at = datetime.utcnow()
                db.add(task_row)
                db.commit()
        content_safety_blocked = False
        with SessionLocal() as db:
            # Whole-song cache. image_hash is part of the key (as in the upsert below),
            # so a lover whose image changed never gets a video with the old face.
            cached_merge = (
                db.query(SingSongVideo)
                .filter(
                    SingSongVideo.lover_id == lover_id,
                    SingSongVideo.song_id == song_id,
                    SingSongVideo.audio_hash == audio_hash,
                    SingSongVideo.image_hash == image_hash,
                    SingSongVideo.ratio == ratio,
                    SingSongVideo.style_level == style_level,
                    SingSongVideo.status == "succeeded",
                )
                .order_by(SingSongVideo.id.desc())
                .first()
            )
            if cached_merge and cached_merge.merged_video_url:
                merged_video_url = cached_merge.merged_video_url
                merge_id = cached_merge.id
                content_safety_blocked = _is_content_safety_error(cached_merge.error_msg)
        if not merged_video_url:
            logger.info(f"任务 {task_id}: 开始生成分段视频,共 {len(segments)}")
            segment_video_urls: list[tuple[int, str]] = []
            content_safety_triggered = False
            for segment in segments:
                segment_id = segment["id"]
                # segment_index is already 1-based (assigned by _ensure_song_segments).
                segment_index = segment["segment_index"]
                segment_audio_url = segment.get("audio_url") or ""
                segment_duration_ms = int(segment.get("duration_ms") or 0)
                emo_duration_ms = int(segment.get("emo_duration_ms") or segment_duration_ms)
                logger.info(f"任务 {task_id}: 处理第 {segment_index}/{len(segments)} 段视频")
                existing_running = False
                segment_video_id = None
                segment_video_url = ""
                # Reuse a finished clip, join an in-flight one, or claim the segment ourselves.
                with SessionLocal() as db:
                    segment_video = (
                        db.query(SongSegmentVideo)
                        .filter(
                            SongSegmentVideo.segment_id == segment_id,
                            SongSegmentVideo.image_hash == image_hash,
                            SongSegmentVideo.style_level == style_level,
                            SongSegmentVideo.model == EMO_MODEL,
                        )
                        .with_for_update()
                        .first()
                    )
                    if segment_video and segment_video.status == "succeeded" and segment_video.video_url:
                        segment_video_url = segment_video.video_url
                        segment_video_id = segment_video.id
                    elif segment_video and segment_video.status == "running":
                        existing_running = True
                        segment_video_id = segment_video.id
                    elif segment_video and segment_video.status == "failed" and _is_content_safety_error(segment_video.error_msg):
                        content_safety_triggered = True
                    else:
                        if not segment_video:
                            segment_video = SongSegmentVideo(
                                user_id=user_id,
                                lover_id=lover_id,
                                song_id=song_id,
                                segment_id=segment_id,
                                image_hash=image_hash,
                                model=EMO_MODEL,
                                ratio=ratio,
                                style_level=style_level,
                                status="running",
                            )
                            db.add(segment_video)
                            db.flush()
                        else:
                            segment_video.status = "running"
                            segment_video.error_msg = None
                            db.add(segment_video)
                        db.commit()
                        segment_video_id = segment_video.id
                if content_safety_triggered:
                    break
                if not segment_video_url and existing_running and segment_video_id:
                    # Another worker is rendering this segment: wait for it instead of resubmitting.
                    waited = _wait_for_segment_video(segment_video_id, EMO_TASK_TIMEOUT_SECONDS)
                    if waited and waited.status == "succeeded" and waited.video_url:
                        segment_video_url = waited.video_url
                    elif waited and waited.status == "failed" and _is_content_safety_error(waited.error_msg):
                        content_safety_triggered = True
                if content_safety_triggered:
                    break
                if not segment_video_url:
                    try:
                        with _semaphore_guard(_emo_task_semaphore):
                            emo_audio_url = segment_audio_url
                            if (
                                emo_duration_ms > segment_duration_ms
                                and segment_duration_ms > 0
                                and segment_audio_url
                            ):
                                # Clip is shorter than the EMO minimum: pad it, upload the padded copy.
                                with tempfile.TemporaryDirectory() as tmpdir:
                                    input_path = os.path.join(tmpdir, f"segment_{segment_index}.mp3")
                                    padded_path = os.path.join(tmpdir, f"segment_{segment_index}_emo.mp3")
                                    _download_to_path(segment_audio_url, input_path)
                                    pad_sec = (emo_duration_ms - segment_duration_ms) / 1000.0
                                    target_sec = emo_duration_ms / 1000.0
                                    _pad_audio_segment(input_path, pad_sec, target_sec, padded_path)
                                    with open(padded_path, "rb") as file_handle:
                                        padded_bytes = file_handle.read()
                                    object_name = (
                                        f"song/{song_id}/segments/"
                                        f"{audio_hash}_{segment_index}_emo.mp3"
                                    )
                                    emo_audio_url = _upload_to_oss(padded_bytes, object_name)
                            dash_task_id = _submit_emo_video(
                                image_url=image_url,
                                audio_url=emo_audio_url,
                                face_bbox=face_bbox or [],
                                ext_bbox=ext_bbox or [],
                                style_level=style_level,
                            )
                            logger.info(f"任务 {task_id}: 第 {segment_index} 段已提交到 DashScopetask_id={dash_task_id}")
                            # Persist the DashScope id so waiters/backfill can follow this job.
                            with SessionLocal() as db:
                                segment_video = (
                                    db.query(SongSegmentVideo)
                                    .filter(SongSegmentVideo.id == segment_video_id)
                                    .with_for_update()
                                    .first()
                                )
                                if segment_video:
                                    segment_video.dashscope_task_id = dash_task_id
                                    segment_video.status = "running"
                                    segment_video.error_msg = None
                                    segment_video.updated_at = datetime.utcnow()
                                    db.add(segment_video)
                                    db.commit()
                            dash_video_url = _poll_video_url(dash_task_id, EMO_TASK_TIMEOUT_SECONDS)
                            logger.info(f"任务 {task_id}: 第 {segment_index} 段视频生成完成URL={dash_video_url[:100]}...")
                            video_bytes = _download_binary(dash_video_url)
                            if emo_duration_ms > segment_duration_ms and segment_duration_ms > 0:
                                # Cut the padding back off so the merged video stays in sync with the song.
                                video_bytes = _trim_video_bytes(video_bytes, segment_duration_ms / 1000.0)
                            object_name = (
                                f"lover/{lover_id}/sing/segments/"
                                f"{song_id}_{image_hash}_{style_level}_{segment_index}.mp4"
                            )
                            segment_video_url = _upload_to_oss(video_bytes, object_name)
                            with SessionLocal() as db:
                                segment_video = (
                                    db.query(SongSegmentVideo)
                                    .filter(SongSegmentVideo.id == segment_video_id)
                                    .with_for_update()
                                    .first()
                                )
                                if segment_video:
                                    segment_video.video_url = segment_video_url
                                    segment_video.status = "succeeded"
                                    segment_video.error_msg = None
                                    segment_video.updated_at = datetime.utcnow()
                                    db.add(segment_video)
                                    db.commit()
                    except Exception as exc:
                        with SessionLocal() as db:
                            segment_video = (
                                db.query(SongSegmentVideo)
                                .filter(SongSegmentVideo.id == segment_video_id)
                                .with_for_update()
                                .first()
                            )
                            if segment_video:
                                segment_video.status = "failed"
                                segment_video.error_msg = (_extract_error_text(exc) or "生成失败")[:255]
                                segment_video.updated_at = datetime.utcnow()
                                db.add(segment_video)
                                db.commit()
                        # Content-safety rejections stop rendering; everything else fails the task.
                        if _is_content_safety_error(_extract_error_text(exc)):
                            content_safety_triggered = True
                            break
                        raise
                if content_safety_triggered:
                    break
                if segment_video_url:
                    segment_video_urls.append((segment_index, segment_video_url))
            if content_safety_triggered and not segment_video_urls:
                _mark_task_content_blocked(task_id, EMO_CONTENT_SAFETY_MESSAGE)
                return
            if content_safety_triggered:
                # Some clips rendered before the block: publish the partial song, flagged.
                content_safety_blocked = True
            segment_video_urls.sort(key=lambda item: item[0])
            with tempfile.TemporaryDirectory() as tmpdir:
                video_paths = []
                for idx, url in segment_video_urls:
                    local_path = os.path.join(tmpdir, f"segment_{idx}.mp4")
                    _download_to_path(url, local_path)
                    video_paths.append(local_path)
                output_path = os.path.join(tmpdir, "merged.mp4")
                _concat_video_files(video_paths, output_path)
                with open(output_path, "rb") as file_handle:
                    merged_bytes = file_handle.read()
            object_name = f"lover/{lover_id}/sing/{int(time.time())}_{song_id}.mp4"
            merged_video_url = _upload_to_oss(merged_bytes, object_name)
        # Publish: merge row, chat messages, quota deduction and task result in one transaction.
        with SessionLocal() as db:
            task_row = (
                db.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if not task_row:
                return
            payload = task_row.payload or {}
            if not session_id:
                session_id = payload.get("session_id")
            if not user_message_id:
                user_message_id = payload.get("user_message_id")
            if not lover_message_id:
                lover_message_id = payload.get("lover_message_id")
            merge_row = (
                db.query(SingSongVideo)
                .filter(
                    SingSongVideo.lover_id == lover_id,
                    SingSongVideo.song_id == song_id,
                    SingSongVideo.audio_hash == audio_hash,
                    SingSongVideo.image_hash == image_hash,
                    SingSongVideo.ratio == ratio,
                    SingSongVideo.style_level == style_level,
                )
                .with_for_update()
                .first()
            )
            if not merge_row:
                merge_row = SingSongVideo(
                    user_id=user_id,
                    lover_id=lover_id,
                    song_id=song_id,
                    base_video_id=None,
                    audio_url=audio_url,
                    audio_hash=audio_hash,
                    image_hash=image_hash,
                    ratio=ratio,
                    style_level=style_level,
                    merged_video_url=merged_video_url,
                    status="succeeded",
                    error_msg=EMO_CONTENT_SAFETY_CODE if content_safety_blocked else None,
                    generation_task_id=task_row.id,
                )
                db.add(merge_row)
                db.flush()
            else:
                merge_row.user_id = user_id
                merge_row.audio_url = audio_url
                merge_row.audio_hash = audio_hash
                merge_row.image_hash = image_hash
                merge_row.ratio = ratio
                merge_row.style_level = style_level
                merge_row.merged_video_url = merged_video_url
                merge_row.status = "succeeded"
                merge_row.error_msg = EMO_CONTENT_SAFETY_CODE if content_safety_blocked else None
                merge_row.generation_task_id = task_row.id
                db.add(merge_row)
            merge_id = merge_row.id
            session = None
            if session_id:
                session = (
                    db.query(ChatSession)
                    .filter(ChatSession.id == session_id, ChatSession.user_id == user_id)
                    .with_for_update()
                    .first()
                )
            if session is None:
                # No recorded session, or it no longer exists / is not the user's: use the
                # user's active session (created if needed) instead of crashing on
                # `session.id` after the video has already been produced.
                session = _get_or_create_session(db, AuthedUser(id=user_id), Lover(id=lover_id), None)
            now = datetime.utcnow()
            if not user_message_id:
                user_msg = ChatMessage(
                    session_id=session.id,
                    user_id=user_id,
                    lover_id=lover_id,
                    role="user",
                    content_type="text",
                    content=song_title or "点播歌曲",
                    seq=_next_seq(db, session.id),
                    created_at=now,
                    model=settings.LLM_MODEL or "qwen-flash",
                )
                db.add(user_msg)
                db.flush()
            else:
                user_msg = db.query(ChatMessage).filter(ChatMessage.id == user_message_id).first()
                if not user_msg:
                    user_msg = ChatMessage(
                        session_id=session.id,
                        user_id=user_id,
                        lover_id=lover_id,
                        role="user",
                        content_type="text",
                        content=song_title or "点播歌曲",
                        seq=_next_seq(db, session.id),
                        created_at=now,
                        model=settings.LLM_MODEL or "qwen-flash",
                    )
                    db.add(user_msg)
                    db.flush()
            lover_content = _build_sing_message_content(merged_video_url, content_safety_blocked)
            if not lover_message_id:
                lover_msg = ChatMessage(
                    session_id=session.id,
                    user_id=user_id,
                    lover_id=lover_id,
                    role="lover",
                    content_type="text",
                    content=lover_content,
                    seq=_next_seq(db, session.id),
                    created_at=datetime.utcnow(),
                    model=settings.LLM_MODEL or "qwen-flash",
                    extra={
                        "video_url": merged_video_url,
                        "generation_task_id": task_row.id,
                        "song_title": song_title,
                        "generation_status": "succeeded",
                        "content_safety_blocked": content_safety_blocked,
                    },
                )
                db.add(lover_msg)
                db.flush()
            else:
                lover_msg = (
                    db.query(ChatMessage)
                    .filter(ChatMessage.id == lover_message_id)
                    .with_for_update()
                    .first()
                )
                if lover_msg:
                    lover_msg.content = lover_content
                    lover_msg.extra = {
                        **(lover_msg.extra or {}),
                        "video_url": merged_video_url,
                        "generation_status": "succeeded",
                        "content_safety_blocked": content_safety_blocked,
                    }
                    db.add(lover_msg)
                else:
                    lover_msg = ChatMessage(
                        session_id=session.id,
                        user_id=user_id,
                        lover_id=lover_id,
                        role="lover",
                        content_type="text",
                        content=lover_content,
                        seq=_next_seq(db, session.id),
                        created_at=datetime.utcnow(),
                        model=settings.LLM_MODEL or "qwen-flash",
                        extra={
                            "video_url": merged_video_url,
                            "generation_task_id": task_row.id,
                            "song_title": song_title,
                            "generation_status": "succeeded",
                            "content_safety_blocked": content_safety_blocked,
                        },
                    )
                    db.add(lover_msg)
                    db.flush()
            lover_msg.extra = {
                **(lover_msg.extra or {}),
                "video_url": merged_video_url,
                "generation_status": "succeeded",
                "generation_task_id": task_row.id,
                "song_title": song_title,
                "content_safety_blocked": content_safety_blocked,
            }
            lover_msg.tts_status = lover_msg.tts_status or "pending"
            db.add(lover_msg)
            # Charge one video generation per task; the "deducted" flag makes re-runs free.
            already_deducted = (task_row.payload or {}).get("deducted")
            user_row = db.query(User).filter(User.id == user_id).with_for_update().first()
            remaining = user_row.video_gen_remaining if user_row else 0
            if user_row and remaining > 0 and not already_deducted:
                user_row.video_gen_remaining = remaining - 1
                db.add(user_row)
            session.last_message_at = datetime.utcnow()
            db.add(session)
            task_row.status = "succeeded"
            task_row.result_url = merged_video_url
            task_row.payload = {
                **(task_row.payload or {}),
                "merged_video_url": merged_video_url,
                "merge_id": merge_id,
                "deducted": True,
                "content_safety_blocked": content_safety_blocked,
                "session_id": session.id,
                "user_message_id": user_msg.id,
                "lover_message_id": lover_msg.id,
            }
            task_row.updated_at = datetime.utcnow()
            db.add(task_row)
            db.commit()
    except HTTPException as exc:
        logger.error(f"任务 {task_id} 处理失败 (HTTPException): {exc.detail if hasattr(exc, 'detail') else str(exc)}")
        try:
            _mark_task_failed(task_id, str(exc.detail) if hasattr(exc, "detail") else str(exc))
        except Exception as e2:
            logger.exception(f"标记任务 {task_id} 失败时出错: {e2}")
    except Exception as exc:
        logger.exception(f"任务 {task_id} 处理失败 (Exception): {exc}")
        try:
            _mark_task_failed(task_id, str(exc)[:255])
        except Exception as e2:
            logger.exception(f"标记任务 {task_id} 失败时出错: {e2}")
@router.post("/generate", response_model=ApiResponse[SingTaskStatusOut])
def generate_sing_video(
    payload: SingGenerateIn,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Submit a request for the current user's lover to sing ``payload.song_id``.

    Validation (each failure raises ``HTTPException``):
      * 404 if the user has no lover; 400 if the lover has no ``image_url``;
      * 404 if the song is missing, disabled (``status`` false) or soft-deleted
        (``deletetime`` set); 400 if its ``gender`` differs from the lover's;
      * 404 if the ``User`` row is missing; 400 if ``video_gen_remaining`` is
        not positive after ``_check_and_reset_vip_video_gen`` ran;
      * 409 if the user already has a pending/running video task;
      * 400 if the song audio URL or the lover image URL is unusable.

    Two outcomes:
      * Cache hit: a succeeded ``SingSongVideo`` already exists for this
        lover / song / audio hash / ratio / style level. A ``succeeded``
        ``GenerationTask`` is written together with the user and lover chat
        messages. One video credit is deducted immediately, and the cached
        video URL is returned.
      * Cache miss: the lover image must pass EMO detection (400 if not,
        502 if the face/ext bboxes are missing). A ``pending`` task and
        placeholder chat messages are written and committed. The task id is
        handed to ``_enqueue_sing_task`` and ``pending`` is returned. No
        credit is deducted in this function on this path.

    The ``User`` row is read ``with_for_update`` so concurrent quota checks
    for the same user are serialized while this transaction is open.
    ``background_tasks`` is accepted but not used here.
    """
    logger.info(f"🎤 收到唱歌生成请求: user_id={user.id}, song_id={payload.song_id}")
    # Resolve the caller's lover; singing requires a confirmed avatar image.
    lover = db.query(Lover).filter(Lover.user_id == user.id).first()
    if not lover:
        raise HTTPException(status_code=404, detail="恋人不存在,请先完成创建流程")
    if not lover.image_url:
        raise HTTPException(status_code=400, detail="请先生成并确认恋人形象")
    song = (
        db.query(SongLibrary)
        .filter(
            SongLibrary.id == payload.song_id,
            SongLibrary.status.is_(True),
            SongLibrary.deletetime.is_(None),
        )
        .first()
    )
    if not song:
        raise HTTPException(status_code=404, detail="歌曲未找到或已下架")
    if song.gender != lover.gender:
        raise HTTPException(status_code=400, detail="歌曲版本性别与恋人不匹配")
    # Row lock on the user: serializes quota checks/decrements for this user
    # until the transaction ends (commit or rollback).
    user_row = (
        db.query(User)
        .filter(User.id == user.id)
        .with_for_update()
        .first()
    )
    if not user_row:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Check and, if due, reset the VIP user's video generation quota.
    _check_and_reset_vip_video_gen(user_row, db)
    if (user_row.video_gen_remaining or 0) <= 0:
        raise HTTPException(status_code=400, detail="视频生成次数不足")
    # Only one in-flight video task per user.
    pending_task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
            GenerationTask.status.in_(["pending", "running"]),
        )
        .first()
    )
    if pending_task:
        raise HTTPException(status_code=409, detail="已有视频生成任务进行中,请稍后再试")
    if not song.audio_url:
        raise HTTPException(status_code=400, detail="该歌曲无可用音频")
    audio_url = _cdnize(song.audio_url)
    if not audio_url:
        raise HTTPException(status_code=400, detail="音频地址不可用")
    audio_url_hash = _hash_text(audio_url)
    audio_hash = song.audio_hash or audio_url_hash
    ratio = EMO_RATIO
    style_level = EMO_STYLE_LEVEL
    # A cached merge may have been keyed by the song's stored audio_hash or by
    # the hash of the CDN audio URL, so match either.
    audio_hash_candidates = [audio_hash]
    if audio_url_hash not in audio_hash_candidates:
        audio_hash_candidates.append(audio_url_hash)
    cached_merge = (
        db.query(SingSongVideo)
        .filter(
            SingSongVideo.lover_id == lover.id,
            SingSongVideo.song_id == song.id,
            SingSongVideo.audio_hash.in_(audio_hash_candidates),
            SingSongVideo.ratio == ratio,
            SingSongVideo.style_level == style_level,
            SingSongVideo.status == "succeeded",
        )
        .order_by(SingSongVideo.id.desc())
        .first()
    )
    cached_merge_url = cached_merge.merged_video_url if cached_merge else ""
    cached_merge_id = cached_merge.id if cached_merge else None
    cached_merge_audio_hash = cached_merge.audio_hash if cached_merge else None
    cached_content_safety = _is_content_safety_error(cached_merge.error_msg) if cached_merge else False
    image_url = _cdnize(lover.image_url)
    if not image_url:
        raise HTTPException(status_code=400, detail="形象地址不可用")
    image_hash = _hash_text(image_url)
    # Idempotency key: user + song + audio + EMO parameters. A unique
    # constraint on idempotency_key is implied by the IntegrityError handling below.
    idem_key_src = f"sing:{user.id}:{song.id}:{audio_hash}:{ratio}:{style_level}"
    idem_key = hashlib.sha256(idem_key_src.encode("utf-8")).hexdigest()
    if cached_merge_url:
        # Cache hit: reuse the merged video and record an already-succeeded task.
        task = GenerationTask(
            user_id=user.id,
            lover_id=lover.id,
            task_type="video",
            status="succeeded",
            idempotency_key=idem_key,
            result_url=cached_merge_url,
            payload={
                "image_url": image_url,
                "audio_url": audio_url,
                "song_id": song.id,
                "song_title": song.title,
                "image_hash": image_hash,
                "audio_hash": cached_merge_audio_hash or audio_hash,
                "ratio": ratio,
                "style_level": style_level,
                "merged_video_url": cached_merge_url,
                "merge_id": cached_merge_id,
                "deducted": True,
                "content_safety_blocked": cached_content_safety,
            },
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
        )
        db.add(task)
        try:
            db.flush()
        except IntegrityError:
            # NOTE(review): rollback() ends the whole transaction. This releases
            # the with_for_update lock on user_row taken above, so the quota
            # decrement further down runs on a reloaded, unlocked row. Whether
            # it also discards changes made by _check_and_reset_vip_video_gen
            # depends on whether that helper commits; verify.
            db.rollback()
            existing = (
                db.query(GenerationTask)
                .filter(GenerationTask.idempotency_key == idem_key)
                .first()
            )
            if existing and existing.status in ("pending", "running"):
                raise HTTPException(status_code=409, detail="已有视频生成任务进行中,请稍后再试")
            # The key belongs to a finished task; derive a unique one for this task.
            retry_key = hashlib.sha256(f"{idem_key}:{time.time()}".encode()).hexdigest()
            task.idempotency_key = retry_key
            db.add(task)
            db.flush()
        session = _get_or_create_session(db, user, lover, None)
        now = datetime.utcnow()
        next_seq = _next_seq(db, session.id)
        # The user's request message (song title) and the lover's reply take
        # two consecutive sequence numbers.
        user_msg = ChatMessage(
            session_id=session.id,
            user_id=user.id,
            lover_id=lover.id,
            role="user",
            content_type="text",
            content=song.title or "点播歌曲",
            seq=next_seq,
            created_at=now,
            model=settings.LLM_MODEL or "qwen-flash",
        )
        db.add(user_msg)
        db.flush()
        lover_content = _build_sing_message_content(cached_merge_url, cached_content_safety)
        lover_msg = ChatMessage(
            session_id=session.id,
            user_id=user.id,
            lover_id=lover.id,
            role="lover",
            content_type="text",
            content=lover_content,
            seq=next_seq + 1,
            created_at=datetime.utcnow(),
            model=settings.LLM_MODEL or "qwen-flash",
            extra={
                "generation_task_id": task.id,
                "generation_status": "succeeded",
                "song_title": song.title,
                "video_url": cached_merge_url,
                "content_safety_blocked": cached_content_safety,
            },
            tts_status="pending",
        )
        db.add(lover_msg)
        db.flush()
        session.last_message_at = datetime.utcnow()
        db.add(session)
        # Link the task to the messages it produced.
        task.payload = {
            **(task.payload or {}),
            "session_id": session.id,
            "user_message_id": user_msg.id,
            "lover_message_id": lover_msg.id,
        }
        db.add(task)
        # Cache hits are charged immediately (payload already says deducted=True).
        remaining = user_row.video_gen_remaining if user_row else 0
        if user_row and remaining > 0:
            user_row.video_gen_remaining = remaining - 1
            db.add(user_row)
        db.commit()
        return success_response(
            SingTaskStatusOut(
                generation_task_id=task.id,
                status="succeeded",
                dashscope_task_id="",
                video_url=cached_merge_url,
                session_id=session.id,
                user_message_id=user_msg.id,
                lover_message_id=lover_msg.id,
                error_msg=None,
            ),
            msg="视频生成成功",
        )
    # Cache miss: the lover image must pass EMO detection; the detected
    # bounding boxes are stored in the task payload for the worker.
    # NOTE(review): user_row is still locked here. If _ensure_emo_detect_cache
    # calls a remote service, the lock is held for that call's duration.
    detect = _ensure_emo_detect_cache(db, lover.id, image_url, image_hash, ratio)
    if not detect.check_pass:
        raise HTTPException(status_code=400, detail="恋人形象未通过EMO检测")
    if not detect.face_bbox or not detect.ext_bbox:
        raise HTTPException(status_code=502, detail="EMO检测返回缺少人脸或动态区域")
    task = GenerationTask(
        user_id=user.id,
        lover_id=lover.id,
        task_type="video",
        status="pending",
        idempotency_key=idem_key,
        payload={
            "image_url": image_url,
            "audio_url": audio_url,
            "song_id": song.id,
            "song_title": song.title,
            "image_hash": image_hash,
            "audio_hash": audio_hash,
            "ratio": ratio,
            "style_level": style_level,
            "face_bbox": detect.face_bbox,
            "ext_bbox": detect.ext_bbox,
        },
        created_at=datetime.utcnow(),
        updated_at=datetime.utcnow(),
    )
    db.add(task)
    try:
        db.flush()
    except IntegrityError:
        # NOTE(review): as above, rollback() releases the user_row lock.
        db.rollback()
        existing = (
            db.query(GenerationTask)
            .filter(GenerationTask.idempotency_key == idem_key)
            .first()
        )
        if existing and existing.status in ("pending", "running"):
            raise HTTPException(status_code=409, detail="已提交相同的唱歌视频生成任务,请稍后查看结果")
        # The key belongs to a finished task; derive a unique one for this task.
        retry_key = hashlib.sha256(f"{idem_key}:{time.time()}".encode()).hexdigest()
        task.idempotency_key = retry_key
        db.add(task)
        db.flush()
    # Reserve the session and message sequence numbers now, so that messages
    # sent later cannot "jump the queue" ahead of the song once generation completes.
    session = _get_or_create_session(db, user, lover, None)
    now = datetime.utcnow()
    next_seq = _next_seq(db, session.id)
    user_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="user",
        content_type="text",
        content=song.title or "点播歌曲",
        seq=next_seq,
        created_at=now,
        model=settings.LLM_MODEL or "qwen-flash",
    )
    db.add(user_msg)
    db.flush()
    # Placeholder lover reply; updated in place once the video is ready.
    lover_msg = ChatMessage(
        session_id=session.id,
        user_id=user.id,
        lover_id=lover.id,
        role="lover",
        content_type="text",
        content="正在为你生成唱歌视频,完成后会自动更新此消息",
        seq=next_seq + 1,
        created_at=datetime.utcnow(),
        model=settings.LLM_MODEL or "qwen-flash",
        extra={
            "generation_task_id": task.id,
            "generation_status": "pending",
            "song_title": song.title,
        },
        tts_status="pending",
    )
    db.add(lover_msg)
    db.flush()
    session.last_message_at = datetime.utcnow()
    db.add(session)
    task.payload = {
        **(task.payload or {}),
        "session_id": session.id,
        "user_message_id": user_msg.id,
        "lover_message_id": lover_msg.id,
    }
    db.add(task)
    # Commit before enqueueing so the task row is visible to other sessions.
    db.commit()
    _enqueue_sing_task(task.id)
    return success_response(
        SingTaskStatusOut(
            generation_task_id=task.id,
            status="pending",
            dashscope_task_id="",
            video_url="",
            session_id=session.id,
            user_message_id=user_msg.id,
            lover_message_id=lover_msg.id,
            error_msg=None,
        ),
        msg="视频生成任务已提交,正在生成",
    )
def _ensure_sing_history_record(db: Session, task: GenerationTask) -> None:
"""确保 nf_sing_song_video 中存在该任务对应的成功记录(用于历史列表)。"""
if not task or task.status != "succeeded":
return
payload = task.payload or {}
song_id = payload.get("song_id")
merged_video_url = payload.get("merged_video_url") or task.result_url
if not song_id or not merged_video_url:
return
existing = (
db.query(SingSongVideo)
.filter(
SingSongVideo.generation_task_id == task.id,
SingSongVideo.user_id == task.user_id,
)
.first()
)
if existing and existing.status == "succeeded" and existing.merged_video_url:
return
audio_url = payload.get("audio_url") or ""
audio_hash = payload.get("audio_hash") or ( _hash_text(audio_url) if audio_url else "" )
image_hash = payload.get("image_hash") or ""
ratio = payload.get("ratio") or EMO_RATIO
style_level = payload.get("style_level") or EMO_STYLE_LEVEL
if not existing:
existing = SingSongVideo(
user_id=task.user_id,
lover_id=task.lover_id or 0,
song_id=int(song_id),
audio_url=audio_url or "",
audio_hash=(audio_hash or "")[:64],
image_hash=(image_hash or "")[:64] if image_hash else None,
ratio=ratio,
style_level=style_level,
merged_video_url=merged_video_url,
status="succeeded",
error_msg=None,
generation_task_id=task.id,
created_at=task.created_at or datetime.utcnow(),
updated_at=datetime.utcnow(),
)
else:
existing.song_id = int(song_id)
existing.merged_video_url = merged_video_url
existing.status = "succeeded"
existing.error_msg = None
existing.updated_at = datetime.utcnow()
existing.generation_task_id = task.id
if task.lover_id:
existing.lover_id = task.lover_id
if audio_url:
existing.audio_url = audio_url
if audio_hash:
existing.audio_hash = (audio_hash or "")[:64]
if image_hash:
existing.image_hash = (image_hash or "")[:64]
existing.ratio = ratio
existing.style_level = style_level
db.add(existing)
def _retry_finalize_sing_task(task_id: int) -> None:
    """Best-effort self-heal for a failed sing task whose DashScope job succeeded.

    This runs when the task failed locally (e.g. download/upload failed) but
    DashScope reports success. It applies only if the task is still
    ``failed``, has a ``dashscope_task_id``, and has no ``retry_finalized``
    flag yet, and only if DashScope reports ``SUCCEEDED`` with a video URL.
    In that case the video is downloaded, re-uploaded to OSS, and the task is
    marked ``succeeded``. Its payload is flagged ``retry_finalized`` and the
    DashScope error text is kept.

    The history row is upserted. If the payload references a lover chat
    message and a session, the message is switched to the finished video and
    the session's ``last_message_at`` is bumped, matching the other success
    paths.

    Network I/O runs without holding a row lock. The task row is locked and
    its status re-checked only for the final write, so concurrent retries
    for the same task cannot both overwrite it. Errors are logged and
    swallowed.

    Args:
        task_id: Primary key of the ``GenerationTask`` to finalize.
    """
    try:
        # Phase 1: snapshot the task without a row lock, so the slow network
        # steps below do not block other writers of this row.
        with SessionLocal() as db:
            task = db.query(GenerationTask).filter(GenerationTask.id == task_id).first()
            if not task or task.status != "failed":
                return
            payload = dict(task.payload or {})
            lover_id = task.lover_id
        dash_id = payload.get("dashscope_task_id")
        if not dash_id or payload.get("retry_finalized"):
            return
        status, dash_video_url, error_msg = _query_dashscope_task_status(dash_id)
        if status != "SUCCEEDED" or not dash_video_url:
            return
        video_bytes = _download_binary(dash_video_url)
        object_name = (
            f"lover/{lover_id}/sing/"
            f"{int(time.time())}_{payload.get('song_id') or 'unknown'}.mp4"
        )
        merged_video_url = _upload_to_oss(video_bytes, object_name)

        # Phase 2: lock the row and re-check. get_sing_task schedules a retry
        # on every poll of a failed task, so several retries may race; only
        # one that still sees "failed" under the lock finalizes it.
        with SessionLocal() as db:
            task = (
                db.query(GenerationTask)
                .filter(GenerationTask.id == task_id)
                .with_for_update()
                .first()
            )
            if not task or task.status != "failed":
                return
            current_payload = task.payload or {}
            task.status = "succeeded"
            task.result_url = merged_video_url
            task.error_msg = None
            task.payload = {
                **current_payload,
                "merged_video_url": merged_video_url,
                "retry_finalized": True,
                "dashscope_error": error_msg,
            }
            task.updated_at = datetime.utcnow()
            db.add(task)

            # Switch the lover's chat message to the finished video.
            blocked = bool(current_payload.get("content_safety_blocked"))
            lover_msg_id = current_payload.get("lover_message_id")
            if lover_msg_id:
                lover_msg = (
                    db.query(ChatMessage)
                    .filter(ChatMessage.id == lover_msg_id)
                    .with_for_update()
                    .first()
                )
                if lover_msg:
                    lover_msg.content = _build_sing_message_content(merged_video_url, blocked)
                    lover_msg.extra = {
                        **(lover_msg.extra or {}),
                        "video_url": merged_video_url,
                        "generation_status": "succeeded",
                        "content_safety_blocked": blocked,
                    }
                    lover_msg.tts_status = lover_msg.tts_status or "pending"
                    db.add(lover_msg)
            session_id = current_payload.get("session_id")
            if session_id:
                session = (
                    db.query(ChatSession)
                    .filter(ChatSession.id == session_id, ChatSession.user_id == task.user_id)
                    .with_for_update()
                    .first()
                )
                if session:
                    session.last_message_at = datetime.utcnow()
                    db.add(session)

            _ensure_sing_history_record(db, task)
            db.commit()
    except Exception:
        # Best-effort: a later poll may schedule another attempt.
        logger.exception("sing task %s: retry finalize failed", task_id)
@router.get("/generate/{task_id}", response_model=ApiResponse[SingTaskStatusOut])
def get_sing_task(
    task_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    user: AuthedUser = Depends(get_current_user),
):
    """Return the status of one of the caller's sing-video generation tasks.

    Besides reading the task, this endpoint performs several self-heal steps:

    * pending/running tasks are passed to ``_enqueue_sing_task`` again;
    * if a task that has not succeeded yet references a ``merge_id`` whose
      ``SingSongVideo`` row already succeeded, several updates follow. The
      task is marked succeeded, its lover chat message and the session
      timestamp are updated, and one video credit is deducted unless the
      payload says it already was;
    * succeeded tasks get their ``nf_sing_song_video`` history row upserted;
    * failed tasks that recorded a DashScope task id schedule
      ``_retry_finalize_sing_task`` as a background task.

    Raises:
        HTTPException: 404 when no video task with this id belongs to the
            current user.
    """
    task = (
        db.query(GenerationTask)
        .filter(
            GenerationTask.id == task_id,
            GenerationTask.user_id == user.id,
            GenerationTask.task_type == "video",
        )
        .first()
    )
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    status_msg_map = {
        "pending": "视频生成中",
        "running": "视频生成中",
        "succeeded": "视频生成成功",
        "failed": "视频生成失败",
    }
    resp_msg = status_msg_map.get(task.status or "", "查询成功")
    if task.status in ("pending", "running"):
        # Re-submitted on every poll; assumes _enqueue_sing_task tolerates
        # duplicates. Verify.
        _enqueue_sing_task(task.id)
    payload = task.payload or {}
    _backfill_running_segments(payload)
    merge_id = payload.get("merge_id")
    # Heal from the merge record only while the task is not finalized yet.
    # Code paths that mark a task succeeded write its message and quota
    # changes in the same commit. Repeating that on every poll would re-lock
    # rows, rewrite the message and bump the session's last_message_at on
    # each read.
    if merge_id and task.status != "succeeded":
        with SessionLocal() as tmp:
            merge = tmp.query(SingSongVideo).filter(SingSongVideo.id == merge_id).first()
            if merge and merge.status == "succeeded" and merge.merged_video_url:
                current = (
                    tmp.query(GenerationTask)
                    .filter(GenerationTask.id == task.id, GenerationTask.user_id == user.id)
                    .with_for_update()
                    .first()
                )
                if current:
                    content_safety_blocked = _is_content_safety_error(merge.error_msg) or bool(
                        (current.payload or {}).get("content_safety_blocked")
                    )
                    current.status = "succeeded"
                    current.result_url = merge.merged_video_url
                    current.payload = {
                        **(current.payload or {}),
                        "merged_video_url": merge.merged_video_url,
                        "content_safety_blocked": content_safety_blocked,
                    }
                    current.updated_at = datetime.utcnow()
                    try:
                        lover_msg_id = (current.payload or {}).get("lover_message_id")
                        session_id = (current.payload or {}).get("session_id")
                        if lover_msg_id:
                            lover_msg = (
                                tmp.query(ChatMessage)
                                .filter(ChatMessage.id == lover_msg_id)
                                .with_for_update()
                                .first()
                            )
                            if lover_msg:
                                lover_msg.content = _build_sing_message_content(
                                    merge.merged_video_url, content_safety_blocked
                                )
                                lover_msg.extra = {
                                    **(lover_msg.extra or {}),
                                    "video_url": merge.merged_video_url,
                                    "generation_status": "succeeded",
                                    "content_safety_blocked": content_safety_blocked,
                                }
                                lover_msg.tts_status = lover_msg.tts_status or "pending"
                                tmp.add(lover_msg)
                        if session_id:
                            session = (
                                tmp.query(ChatSession)
                                .filter(ChatSession.id == session_id, ChatSession.user_id == user.id)
                                .with_for_update()
                                .first()
                            )
                            if session:
                                session.last_message_at = datetime.utcnow()
                                tmp.add(session)
                        # Charge the credit only once per task.
                        if not (current.payload or {}).get("deducted"):
                            user_row = (
                                tmp.query(User)
                                .filter(User.id == current.user_id)
                                .with_for_update()
                                .first()
                            )
                            if user_row:
                                remaining = user_row.video_gen_remaining or 0
                                if remaining > 0:
                                    user_row.video_gen_remaining = remaining - 1
                                current.payload = {**(current.payload or {}), "deducted": True}
                                tmp.add(user_row)
                    except Exception:
                        # Best-effort: the task status update is still committed below.
                        logger.exception(
                            "sing task %s: failed to sync message/quota from merge %s",
                            task_id,
                            merge_id,
                        )
                    tmp.add(current)
                    tmp.commit()
                    # Reload before the temporary session closes. Otherwise,
                    # with expire_on_commit, reading attributes of the detached
                    # object below would raise DetachedInstanceError.
                    tmp.refresh(current)
                    task = current
    # Self-heal: a succeeded task must have its history row in nf_sing_song_video.
    if task.status == "succeeded":
        try:
            _ensure_sing_history_record(db, task)
            db.commit()
        except Exception:
            db.rollback()
            logger.exception("sing task %s: failed to backfill history record", task_id)
    # Self-heal: a failed task may still have succeeded on DashScope (e.g. the
    # download/upload step failed), so schedule a background finalize attempt.
    if task.status == "failed":
        payload = task.payload or {}
        if payload.get("dashscope_task_id") and not payload.get("retry_finalized"):
            background_tasks.add_task(_retry_finalize_sing_task, int(task.id))
    resp_msg = status_msg_map.get(task.status or "", resp_msg)
    payload = task.payload or {}
    return success_response(
        SingTaskStatusOut(
            generation_task_id=task.id,
            status=task.status,
            dashscope_task_id=str(payload.get("dashscope_task_id") or ""),
            video_url=task.result_url or payload.get("merged_video_url") or "",
            session_id=int(payload.get("session_id") or 0),
            user_message_id=int(payload.get("user_message_id") or 0),
            lover_message_id=int(payload.get("lover_message_id") or 0),
            error_msg=task.error_msg,
        ),
        msg=resp_msg,
    )