"""轻量封装通义千问 (DashScope) 文本生成。

目前主要用于聊天初始化/对话,默认模型 qwen-flash,可通过 .env 覆盖。
"""
from typing import Dict, Iterable, List, Optional

import dashscope
from dashscope import Generation
from fastapi import HTTPException

from .config import settings


class LLMResult:
    """Container for a completed (non-streaming) generation response."""

    def __init__(self, content: str, usage: Optional[dict]):
        self.content = content
        # Normalize a missing/empty usage payload to an empty dict so
        # callers never have to null-check it.
        self.usage = usage if usage else {}


class LLMStreamResponse:
    """Wrapper around a streaming response iterator.

    Iterating yields non-empty text fragments; once the stream is fully
    consumed, any token-usage info reported by the backend is available
    on ``self.usage``.
    """

    def __init__(self, iterator: Iterable):
        self._iterator = iterator
        # Updated from chunks as they arrive; stays {} if never reported.
        self.usage: Dict = {}

    def __iter__(self):
        for piece in self._iterator:
            self._update_usage(piece)
            fragment = _extract_text(piece)
            if not fragment:
                continue
            yield fragment

    def _update_usage(self, chunk):
        # Best-effort: a chunk may lack the usage attribute entirely.
        try:
            reported = chunk.usage
        except Exception:
            return
        if reported:
            self.usage = reported


def chat_completion(
    messages: List[Dict[str, str]],
    *,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> LLMResult:
    """Synchronous call that returns the complete generated text.

    Args:
        messages: Chat history in ``[{"role": ..., "content": ...}]`` form.
        model: Override for the configured model; falls back to
            ``settings.LLM_MODEL`` and then ``"qwen-flash"``.
        temperature / max_tokens: Per-call overrides of the settings defaults.
        seed: Optional sampling seed forwarded to DashScope.

    Raises:
        HTTPException: 500 when the API key is not configured; 502 when
            DashScope reports a non-200 status.
    """
    key = settings.DASHSCOPE_API_KEY
    if not key:
        raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")

    params = {
        "api_key": key,
        "model": model or settings.LLM_MODEL or "qwen-flash",
        "messages": messages,
        "result_format": "message",
        # Explicit None-checks (not truthiness) so 0 / 0.0 overrides count.
        "temperature": settings.LLM_TEMPERATURE if temperature is None else temperature,
        "max_tokens": settings.LLM_MAX_TOKENS if max_tokens is None else max_tokens,
    }
    if seed is not None:
        params["seed"] = seed

    resp = Generation.call(**params)
    if getattr(resp, "status_code", 200) != 200:
        raise HTTPException(
            status_code=502,
            detail=getattr(resp, "message", None) or "LLM 调用失败",
        )

    return LLMResult(
        content=_extract_text(resp),
        usage=getattr(resp, "usage", {}) or {},
    )


def chat_completion_stream(
    messages: List[Dict[str, str]],
    *,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> LLMStreamResponse:
    """Streaming call; iterate the result for text fragments, then read .usage.

    Args:
        messages: Chat history in ``[{"role": ..., "content": ...}]`` form.
        model: Override for the configured model; falls back to
            ``settings.LLM_MODEL`` and then ``"qwen-flash"``.
        temperature / max_tokens: Per-call overrides of the settings defaults.
        seed: Optional sampling seed forwarded to DashScope.

    Raises:
        HTTPException: 500 when the API key is not configured.
    """
    key = settings.DASHSCOPE_API_KEY
    if not key:
        raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY")

    params = {
        "api_key": key,
        "model": model or settings.LLM_MODEL or "qwen-flash",
        "messages": messages,
        "result_format": "message",
        # incremental_output makes each chunk carry only the new delta text.
        "stream": True,
        "incremental_output": True,
        "temperature": settings.LLM_TEMPERATURE if temperature is None else temperature,
        "max_tokens": settings.LLM_MAX_TOKENS if max_tokens is None else max_tokens,
    }
    if seed is not None:
        params["seed"] = seed

    return LLMStreamResponse(Generation.call(**params))


def _extract_text(resp_obj: object) -> str:
|
||
"""兼容 dashscope 同步/流式的文本字段提取。"""
|
||
try:
|
||
choice = resp_obj.output.choices[0].message.content
|
||
except Exception:
|
||
return ""
|
||
|
||
# content 可能是字符串或 list[{text: ...}]
|
||
if isinstance(choice, str):
|
||
return choice
|
||
if isinstance(choice, list):
|
||
parts = []
|
||
for item in choice:
|
||
if isinstance(item, dict) and "text" in item:
|
||
parts.append(str(item.get("text") or ""))
|
||
return "".join(parts)
|
||
return ""
|