Ai_GirlFriend/lover/llm.py
"""轻量封装通义千问 (DashScope) 文本生成。
目前主要用于聊天初始化/对话默认模型 qwen-flash可通过 .env 覆盖
"""
from typing import Dict, Iterable, List, Optional

from dashscope import Generation
from fastapi import HTTPException

from .config import settings

class LLMResult:
    """Non-streaming response result."""

    def __init__(self, content: str, usage: Optional[dict]):
        self.content = content
        self.usage = usage or {}

class LLMStreamResponse:
    """Streaming response wrapper: iterate over text chunks, then read usage when done."""

    def __init__(self, iterator: Iterable):
        self._iterator = iterator
        self.usage: Dict = {}

    def __iter__(self):
        for chunk in self._iterator:
            self._update_usage(chunk)
            text = _extract_text(chunk)
            if text:
                yield text

    def _update_usage(self, chunk):
        # Keep the most recent usage payload reported by the stream, if any.
        try:
            if chunk.usage:
                self.usage = chunk.usage
        except Exception:
            return

def chat_completion(
    messages: List[Dict[str, str]],
    *,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> LLMResult:
    """Synchronous call; returns the complete text."""
    api_key = settings.DASHSCOPE_API_KEY
    if not api_key:
        raise HTTPException(status_code=500, detail="DASHSCOPE_API_KEY is not configured")
    model_name = model or settings.LLM_MODEL or "qwen-flash"
    temp = temperature if temperature is not None else settings.LLM_TEMPERATURE
    max_out = max_tokens if max_tokens is not None else settings.LLM_MAX_TOKENS
    call_kwargs = {
        "api_key": api_key,
        "model": model_name,
        "messages": messages,
        "result_format": "message",
        "temperature": temp,
        "max_tokens": max_out,
    }
    if seed is not None:
        call_kwargs["seed"] = seed
    resp = Generation.call(**call_kwargs)
    if getattr(resp, "status_code", 200) != 200:
        detail = getattr(resp, "message", None) or "LLM call failed"
        raise HTTPException(status_code=502, detail=detail)
    content = _extract_text(resp)
    return LLMResult(content=content, usage=getattr(resp, "usage", {}) or {})
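
# Example usage (a minimal sketch; assumes DASHSCOPE_API_KEY and the LLM_*
# defaults are configured in .env via `settings`):
#
#     result = chat_completion([{"role": "user", "content": "Hello"}])
#     print(result.content)  # full reply text
#     print(result.usage)    # token accounting reported by DashScope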

def chat_completion_stream(
    messages: List[Dict[str, str]],
    *,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> LLMStreamResponse:
    """Streaming call; iterate over text chunks, then read .usage after the call finishes."""
    api_key = settings.DASHSCOPE_API_KEY
    if not api_key:
        raise HTTPException(status_code=500, detail="DASHSCOPE_API_KEY is not configured")
    model_name = model or settings.LLM_MODEL or "qwen-flash"
    temp = temperature if temperature is not None else settings.LLM_TEMPERATURE
    max_out = max_tokens if max_tokens is not None else settings.LLM_MAX_TOKENS
    call_kwargs = {
        "api_key": api_key,
        "model": model_name,
        "messages": messages,
        "result_format": "message",
        "stream": True,
        "incremental_output": True,
        "temperature": temp,
        "max_tokens": max_out,
    }
    if seed is not None:
        call_kwargs["seed"] = seed
    iterator = Generation.call(**call_kwargs)
    return LLMStreamResponse(iterator)
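
# Example usage (sketch): the yielded chunks are plain incremental text, so they
# can be forwarded to e.g. a FastAPI StreamingResponse or an SSE generator:
#
#     stream = chat_completion_stream([{"role": "user", "content": "Hello"}])
#     for piece in stream:
#         print(piece, end="", flush=True)
#     print(stream.usage)  # only populated once the iterator is exhausted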

def _extract_text(resp_obj: object) -> str:
    """Extract the text field from either a sync response or a streaming chunk."""
    try:
        choice = resp_obj.output.choices[0].message.content
    except Exception:
        return ""
    # content may be a plain string or a list like [{"text": ...}, ...]
    if isinstance(choice, str):
        return choice
    if isinstance(choice, list):
        parts = []
        for item in choice:
            if isinstance(item, dict) and "text" in item:
                parts.append(str(item.get("text") or ""))
        return "".join(parts)
    return ""