"""轻量封装通义千问 (DashScope) 文本生成。 目前主要用于聊天初始化/对话,默认模型 qwen-flash,可通过 .env 覆盖。 """ from typing import Dict, Iterable, List, Optional import dashscope from dashscope import Generation from fastapi import HTTPException from .config import settings class LLMResult: """非流式响应结果。""" def __init__(self, content: str, usage: Optional[dict]): self.content = content self.usage = usage or {} class LLMStreamResponse: """流式响应包装器,可迭代文本片段,结束后读取 usage。""" def __init__(self, iterator: Iterable): self._iterator = iterator self.usage: Dict = {} def __iter__(self): for chunk in self._iterator: self._update_usage(chunk) text = _extract_text(chunk) if text: yield text def _update_usage(self, chunk): try: if chunk.usage: self.usage = chunk.usage except Exception: return def chat_completion( messages: List[Dict[str, str]], *, model: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, seed: Optional[int] = None, ) -> LLMResult: """同步调用,返回完整文本。""" api_key = settings.DASHSCOPE_API_KEY if not api_key: raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY") model_name = model or settings.LLM_MODEL or "qwen-flash" temp = temperature if temperature is not None else settings.LLM_TEMPERATURE max_out = max_tokens if max_tokens is not None else settings.LLM_MAX_TOKENS call_kwargs = { "api_key": api_key, "model": model_name, "messages": messages, "result_format": "message", "temperature": temp, "max_tokens": max_out, } if seed is not None: call_kwargs["seed"] = seed resp = Generation.call(**call_kwargs) if getattr(resp, "status_code", 200) != 200: detail = getattr(resp, "message", None) or "LLM 调用失败" raise HTTPException(status_code=502, detail=detail) content = _extract_text(resp) return LLMResult(content=content, usage=getattr(resp, "usage", {}) or {}) def chat_completion_stream( messages: List[Dict[str, str]], *, model: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, seed: Optional[int] = None, ) -> LLMStreamResponse: """流式调用,迭代文本片段,调用结束后可读取 .usage。""" api_key = settings.DASHSCOPE_API_KEY if not api_key: raise HTTPException(status_code=500, detail="未配置 DASHSCOPE_API_KEY") model_name = model or settings.LLM_MODEL or "qwen-flash" temp = temperature if temperature is not None else settings.LLM_TEMPERATURE max_out = max_tokens if max_tokens is not None else settings.LLM_MAX_TOKENS call_kwargs = { "api_key": api_key, "model": model_name, "messages": messages, "result_format": "message", "stream": True, "incremental_output": True, "temperature": temp, "max_tokens": max_out, } if seed is not None: call_kwargs["seed"] = seed iterator = Generation.call(**call_kwargs) return LLMStreamResponse(iterator) def _extract_text(resp_obj: object) -> str: """兼容 dashscope 同步/流式的文本字段提取。""" try: choice = resp_obj.output.choices[0].message.content except Exception: return "" # content 可能是字符串或 list[{text: ...}] if isinstance(choice, str): return choice if isinstance(choice, list): parts = [] for item in choice: if isinstance(item, dict) and "text" in item: parts.append(str(item.get("text") or "")) return "".join(parts) return ""