Ai_GirlFriend/lover/llm.py

"""轻量封装通义千问 (DashScope) 文本生成。
目前主要用于聊天初始化/对话,默认模型 qwen-flash可通过 .env 覆盖。
"""
from typing import Dict, Iterable, List, Optional
import dashscope
from dashscope import Generation
from fastapi import HTTPException
from .config import settings
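
# Settings this module reads (loaded from .env via .config.settings):
#   DASHSCOPE_API_KEY  -- required; calls fail with HTTP 500 if unset
#   LLM_MODEL          -- optional; falls back to "qwen-flash"
#   LLM_TEMPERATURE    -- default sampling temperature
#   LLM_MAX_TOKENS     -- default output-token cap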


class LLMResult:
    """Non-streaming response result."""

    def __init__(self, content: str, usage: Optional[dict]):
        self.content = content
        self.usage = usage or {}


class LLMStreamResponse:
    """Streaming response wrapper: iterate it for text fragments, then read usage once the stream ends."""

    def __init__(self, iterator: Iterable):
        self._iterator = iterator
        self.usage: Dict = {}

    def __iter__(self):
        for chunk in self._iterator:
            self._update_usage(chunk)
            text = _extract_text(chunk)
            if text:
                yield text

    def _update_usage(self, chunk):
        # Usage is only present on some chunks; keep the latest one seen.
        try:
            if chunk.usage:
                self.usage = chunk.usage
        except Exception:
            return


def chat_completion(
    messages: List[Dict[str, str]],
    *,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> LLMResult:
    """Synchronous call; returns the full generated text."""
    api_key = settings.DASHSCOPE_API_KEY
    if not api_key:
        raise HTTPException(status_code=500, detail="DASHSCOPE_API_KEY is not configured")
    model_name = model or settings.LLM_MODEL or "qwen-flash"
    temp = temperature if temperature is not None else settings.LLM_TEMPERATURE
    max_out = max_tokens if max_tokens is not None else settings.LLM_MAX_TOKENS
    call_kwargs = {
        "api_key": api_key,
        "model": model_name,
        "messages": messages,
        "result_format": "message",
        "temperature": temp,
        "max_tokens": max_out,
    }
    if seed is not None:
        call_kwargs["seed"] = seed
    resp = Generation.call(**call_kwargs)
    if getattr(resp, "status_code", 200) != 200:
        detail = getattr(resp, "message", None) or "LLM call failed"
        raise HTTPException(status_code=502, detail=detail)
    content = _extract_text(resp)
    return LLMResult(content=content, usage=getattr(resp, "usage", {}) or {})
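
# Usage sketch (hypothetical call site; the messages follow the OpenAI-style
# role/content schema that DashScope's "message" result format accepts):
#     result = chat_completion([{"role": "user", "content": "Hello"}], temperature=0.7)
#     reply, tokens = result.content, result.usage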


def chat_completion_stream(
    messages: List[Dict[str, str]],
    *,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    seed: Optional[int] = None,
) -> LLMStreamResponse:
    """Streaming call; iterate for text fragments, then read .usage after the call finishes."""
    api_key = settings.DASHSCOPE_API_KEY
    if not api_key:
        raise HTTPException(status_code=500, detail="DASHSCOPE_API_KEY is not configured")
    model_name = model or settings.LLM_MODEL or "qwen-flash"
    temp = temperature if temperature is not None else settings.LLM_TEMPERATURE
    max_out = max_tokens if max_tokens is not None else settings.LLM_MAX_TOKENS
    call_kwargs = {
        "api_key": api_key,
        "model": model_name,
        "messages": messages,
        "result_format": "message",
        "stream": True,
        "incremental_output": True,
        "temperature": temp,
        "max_tokens": max_out,
    }
    if seed is not None:
        call_kwargs["seed"] = seed
    iterator = Generation.call(**call_kwargs)
    return LLMStreamResponse(iterator)
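
# Usage sketch (hypothetical endpoint; StreamingResponse would come from
# fastapi.responses, which this module does not import):
#     stream = chat_completion_stream(messages)
#     return StreamingResponse(stream, media_type="text/plain")
# Note: stream.usage is populated only after the iterator is fully consumed.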


def _extract_text(resp_obj: object) -> str:
    """Extract the text field from dashscope sync and streaming responses alike."""
    try:
        choice = resp_obj.output.choices[0].message.content
    except Exception:
        return ""
    # content may be a plain string or a list of {"text": ...} parts
    if isinstance(choice, str):
        return choice
    if isinstance(choice, list):
        parts = []
        for item in choice:
            if isinstance(item, dict) and "text" in item:
                parts.append(str(item.get("text") or ""))
        return "".join(parts)
    return ""