84 lines
2.4 KiB
Python
84 lines
2.4 KiB
Python
|
|
"""
|
||
|
|
CosyVoice voice enrollment helper for cosyvoice-v2.
|
||
|
|
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import os
|
||
|
|
import time
|
||
|
|
from typing import Any, Callable, Dict
|
||
|
|
|
||
|
|
import dashscope
|
||
|
|
from dashscope.audio.tts_v2 import VoiceEnrollmentService
|
||
|
|
|
||
|
|
from .config import settings
|
||
|
|
|
||
|
|
# Default ``target_model`` used when enrolling a voice via create_voice_from_url,
# and the default value of the CLI's ``--target-model`` option.
DEFAULT_TARGET_MODEL = "cosyvoice-v2"
|
||
|
|
|
||
|
|
|
||
|
|
def _get_api_key() -> str:
    """Return the DashScope API key, preferring ``settings`` over the environment.

    ``settings.DASHSCOPE_API_KEY`` is used when truthy; otherwise the
    ``DASHSCOPE_API_KEY`` environment variable is consulted.

    Raises:
        RuntimeError: if neither source yields a non-empty key.
    """
    key = settings.DASHSCOPE_API_KEY or os.getenv("DASHSCOPE_API_KEY", "")
    if key:
        return key
    raise RuntimeError("DASHSCOPE_API_KEY is not set")
|
||
|
|
|
||
|
|
|
||
|
|
def create_voice_from_url(
    audio_url: str,
    prefix: str,
    target_model: str = DEFAULT_TARGET_MODEL,
) -> str:
    """Enroll a cloned voice from an audio URL and return its voice_id.

    Sets the module-level ``dashscope.api_key`` before talking to the service.

    Args:
        audio_url: Reference audio location, forwarded to the service as ``url``.
        prefix: Name prefix for the new voice.
        target_model: TTS model the cloned voice is created for.

    Returns:
        The voice_id returned by ``VoiceEnrollmentService.create_voice``.

    Raises:
        RuntimeError: if no DashScope API key is configured.
    """
    dashscope.api_key = _get_api_key()
    enrollment = VoiceEnrollmentService()
    voice_id = enrollment.create_voice(
        target_model=target_model,
        prefix=prefix,
        url=audio_url,
    )
    return voice_id
|
||
|
|
|
||
|
|
|
||
|
|
def query_voice(voice_id: str) -> Dict[str, Any]:
    """Fetch the status and details of an enrolled voice.

    Sets the module-level ``dashscope.api_key`` before talking to the service.

    Args:
        voice_id: Identifier of the voice to look up.

    Returns:
        The result of ``VoiceEnrollmentService.query_voice`` for ``voice_id``.

    Raises:
        RuntimeError: if no DashScope API key is configured.
    """
    dashscope.api_key = _get_api_key()
    details = VoiceEnrollmentService().query_voice(voice_id=voice_id)
    return details
|
||
|
|
|
||
|
|
|
||
|
|
def wait_voice_ready(
    voice_id: str,
    *,
    timeout_sec: int = 300,
    poll_interval: int = 10,
    query: Callable[[str], Dict[str, Any]] | None = None,
) -> Dict[str, Any]:
    """Poll a voice until its status is ``"OK"`` or ``"UNDEPLOYED"``, or time out.

    The voice is always queried at least once, even when ``timeout_sec`` is 0.
    Both terminal statuses are returned to the caller as-is, so callers must
    inspect ``result["status"]`` themselves. ``"UNDEPLOYED"`` presumably
    means the voice is unusable; confirm against the DashScope docs.

    Args:
        voice_id: Voice to poll.
        timeout_sec: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between queries. The sleep is capped
            so the loop never sleeps past the deadline.
        query: Optional replacement for ``query_voice``, called as
            ``query(voice_id)``. Defaults to ``query_voice``.

    Returns:
        The last query result, whose ``"status"`` is ``"OK"`` or ``"UNDEPLOYED"``.

    Raises:
        TimeoutError: if no terminal status is observed before the deadline.
    """
    terminal_statuses = ("OK", "UNDEPLOYED")
    fetch = query_voice if query is None else query
    # Monotonic clock: a system clock adjustment mid-poll cannot shorten or
    # extend the wait.
    deadline = time.monotonic() + timeout_sec
    while True:
        # A falsy/None response is treated as "no status yet".
        last = fetch(voice_id) or {}
        status = last.get("status")
        if status in terminal_statuses:
            return last
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(
                f"Voice {voice_id} is not ready before timeout "
                f"(last status: {status!r})"
            )
        # Cap the sleep so the final query lands at the deadline instead of
        # oversleeping and giving up without checking again.
        time.sleep(min(poll_interval, remaining))
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="CosyVoice clone helper")
    parser.add_argument("--url", required=True, help="Public audio URL")
    parser.add_argument("--prefix", required=True, help="Voice name prefix (<=10 chars)")
    parser.add_argument(
        "--target-model",
        default=DEFAULT_TARGET_MODEL,
        help="Target TTS model for the cloned voice",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="Seconds to wait for the voice to reach a terminal status",
    )
    parser.add_argument(
        "--poll",
        type=int,
        default=10,
        help="Seconds between status queries",
    )
    args = parser.parse_args()

    # Enforce the limit advertised in --prefix's help before calling the
    # service, so a bad value is a usage error rather than a remote failure.
    if len(args.prefix) > 10:
        parser.error("--prefix must be at most 10 characters")

    vid = create_voice_from_url(args.url, args.prefix, args.target_model)
    # Printed before waiting so the id is not lost if polling times out.
    print(f"voice_id: {vid}")
    try:
        info = wait_voice_ready(vid, timeout_sec=args.timeout, poll_interval=args.poll)
    except TimeoutError as exc:
        parser.exit(1, f"error: {exc}\n")
    status = info.get("status")
    print(f"status: {status}")
    # wait_voice_ready also returns on "UNDEPLOYED"; exit non-zero so scripts
    # can distinguish a usable voice ("OK") from any other terminal status.
    if status != "OK":
        parser.exit(1)
|