from __future__ import annotations

# Streaming chat service for a local, OpenAI-compatible LLM endpoint.
#
# Settings are read lazily from environment variables, and an optional
# tool-calling loop (enabled via TOOLS_ENABLED) executes tools requested by
# the model before streaming the next round of output.

import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Iterator

from openai import OpenAI

from config.prompt_config import load_prompt
from tools import tool_registry
from utils.env import env_float, env_int, env_str

logger = logging.getLogger(__name__)

DEFAULT_LLM_TIMEOUT_SECONDS = 1800
DEFAULT_LLM_TEMPERATURE = 0


def _accumulate_tool_calls(buffer: dict[int, dict], delta_tool_calls: list) -> None:
    """Merge streamed ``tool_calls`` deltas into ``buffer``, keyed by call index.

    Args:
        buffer: Maps a tool-call index to a dict with ``id``, ``name`` and
            ``arguments`` keys. Mutated in place.
        delta_tool_calls: Delta objects exposing ``index``, ``id`` and
            ``function`` (with ``name`` / ``arguments``) attributes.
    """
    for tc_delta in delta_tool_calls:
        entry = buffer.setdefault(
            tc_delta.index,
            {"id": tc_delta.id or "", "name": "", "arguments": ""},
        )
        if tc_delta.id:
            entry["id"] = tc_delta.id
        fn = tc_delta.function
        if fn:
            # Name and arguments may arrive split over several deltas, so
            # every fragment is appended rather than assigned.
            if fn.name:
                entry["name"] += fn.name
            if fn.arguments:
                entry["arguments"] += fn.arguments


def _coalesce(*values):
    """Return the first argument that is not ``None``, or ``None`` if all are."""
    return next((value for value in values if value is not None), None)


@dataclass(frozen=True)
class LocalLLMSettings:
    """Immutable connection and sampling settings for the local LLM."""

    api_key: str
    model: str
    base_url: str
    timeout_seconds: int
    default_temperature: float
    # Sent to the backend as extra_body["thinking"]["type"].
    thinking_type: str


def _load_settings() -> LocalLLMSettings:
    """Build ``LocalLLMSettings`` from environment variables.

    Raises:
        ValueError: If ``LOCAL_LLM_API_KEY``, ``LOCAL_LLM_MODEL`` or
            ``LOCAL_LLM_BASE_URL`` is empty.
    """
    api_key = env_str("LOCAL_LLM_API_KEY")
    model = env_str("LOCAL_LLM_MODEL")
    base_url = env_str("LOCAL_LLM_BASE_URL")

    if not api_key:
        raise ValueError("LOCAL_LLM_API_KEY 不能为空")
    if not model:
        raise ValueError("LOCAL_LLM_MODEL 不能为空")
    if not base_url:
        raise ValueError("LOCAL_LLM_BASE_URL 不能为空")

    return LocalLLMSettings(
        api_key=api_key,
        model=model,
        base_url=base_url,
        timeout_seconds=env_int(
            "LOCAL_LLM_TIMEOUT_SECONDS", DEFAULT_LLM_TIMEOUT_SECONDS
        ),
        default_temperature=env_float(
            "LOCAL_LLM_TEMPERATURE", DEFAULT_LLM_TEMPERATURE
        ),
        thinking_type=env_str("CUSTOM_LLM_THINKING_TYPE", "disabled"),
    )


class LocalLLMService:
    """Lazily configured client wrapper that streams chat completions."""

    def __init__(self):
        # Both are created on first use so that importing this module does
        # not require the environment to be configured.
        self._client: OpenAI | None = None
        self._settings: LocalLLMSettings | None = None

    @property
    def settings(self) -> LocalLLMSettings:
        """Settings loaded from the environment on first access, then cached."""
        if self._settings is None:
            self._settings = _load_settings()
        return self._settings

    def _get_client(self) -> OpenAI:
        """Return the cached ``OpenAI`` client, creating it on first call."""
        if self._client is not None:
            return self._client

        s = self.settings
        # Pass the values explicitly so the client uses what was loaded into
        # settings (supports per-environment overrides).
        self._client = OpenAI(
            api_key=s.api_key,
            base_url=s.base_url,
            timeout=s.timeout_seconds,
        )
        return self._client

    def _build_payload(
        self,
        history_messages: list[dict[str, Any]],
        request_options: dict[str, Any],
        prompt_name: str,
    ) -> dict[str, Any]:
        """Assemble the chat-completions request payload (without tools)."""
        settings = self.settings

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": load_prompt(prompt_name)}
        ]
        # Drop system messages already present in the history to avoid
        # sending duplicates.
        messages.extend(
            msg for msg in history_messages if msg.get("role") != "system"
        )

        payload: dict[str, Any] = {
            "model": settings.model,
            "messages": messages,
            "temperature": _coalesce(
                request_options.get("temperature"),
                settings.default_temperature,
            ),
            "stream": True,
            "stream_options": {"include_usage": True},
        }
        for option in ("max_tokens", "top_p"):
            if request_options.get(option) is not None:
                payload[option] = request_options[option]

        # The thinking configuration is passed through extra_body.
        payload["extra_body"] = {"thinking": {"type": settings.thinking_type}}
        if settings.thinking_type != "disabled":
            # Allowed values: low / medium / high.
            payload["reasoning_effort"] = "medium"
        return payload

    def chat_stream(
        self,
        history_messages: list[dict[str, Any]],
        request_options: dict[str, Any] | None = None,
        prompt_name: str = "default",
    ) -> Iterator[Any]:
        """Stream a chat completion, resolving tool calls between rounds.

        Each round issues one streaming request. Chunks are yielded as
        received until the first tool-call delta of that round; from then on
        the round's chunks (including chunks without ``choices``) are
        withheld, the requested tools are executed through ``tool_registry``,
        and their results are appended to the conversation for the next
        round. The loop stops at the first round without tool calls or after
        ``TOOLS_MAX_ROUNDS`` rounds (default 3, minimum 1).

        Args:
            history_messages: Prior conversation messages; entries whose
                ``role`` is ``"system"`` are ignored.
            request_options: Optional overrides; ``temperature``,
                ``max_tokens`` and ``top_p`` are read.
            prompt_name: Name passed to ``load_prompt`` for the system prompt.

        Yields:
            Raw stream chunks returned by the client.

        Raises:
            Exception: Any error raised while calling the model or executing
                a tool is logged and re-raised.
        """
        settings = self.settings
        client = self._get_client()
        payload = self._build_payload(
            history_messages, request_options or {}, prompt_name
        )

        # Inject tools when enabled.
        tools_enabled = os.getenv("TOOLS_ENABLED", "false").lower() == "true"
        tools_list = tool_registry.get_openai_tools() if tools_enabled else []
        if tools_list:
            payload["tools"] = tools_list
        # Clamp to at least one round: a zero or negative value would skip
        # the model call entirely and yield nothing.
        max_rounds = max(1, int(os.getenv("TOOLS_MAX_ROUNDS", "3")))

        logger.info(
            "发起流式调用 (Model: %s, Thinking: %s)",
            settings.model,
            settings.thinking_type,
        )
        # Serialising the whole payload is costly; only do it when the debug
        # record will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "请求 payload:\n%s",
                json.dumps(payload, ensure_ascii=False, indent=2),
            )

        try:
            for round_idx in range(max_rounds):
                logger.info(
                    "LLM 调用第 %d 轮 (tools=%d)", round_idx + 1, len(tools_list)
                )
                tc_buffer: dict[int, dict] = {}
                has_tool_calls = False

                stream = client.chat.completions.create(**payload)
                try:
                    for chunk in stream:
                        if not chunk.choices:
                            if not has_tool_calls:
                                yield chunk
                            continue

                        delta = chunk.choices[0].delta

                        # Collect tool-call fragments; these are not yielded
                        # to the caller.
                        if delta and delta.tool_calls:
                            has_tool_calls = True
                            _accumulate_tool_calls(tc_buffer, delta.tool_calls)
                            continue

                        # Ordinary content is passed straight through.
                        if not has_tool_calls:
                            yield chunk
                finally:
                    # Release the connection even when the consumer stops
                    # iterating early (generator closed mid-stream).
                    close_stream = getattr(stream, "close", None)
                    if callable(close_stream):
                        close_stream()

                if not tc_buffer:
                    break  # No tool calls in this round: we are done.

                # Build the assistant message carrying the tool calls.
                tool_calls_for_msg = [
                    {
                        "id": tc["id"],
                        "type": "function",
                        "function": {
                            "name": tc["name"],
                            "arguments": tc["arguments"],
                        },
                    }
                    for _, tc in sorted(tc_buffer.items())
                ]
                payload["messages"].append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": tool_calls_for_msg,
                })

                # Execute each tool and append its result for the next round.
                for tc in tool_calls_for_msg:
                    fn_name = tc["function"]["name"]
                    fn_args = tc["function"]["arguments"]
                    logger.info("执行工具: %s(%s)", fn_name, fn_args)
                    # NOTE(review): assumes execute() returns a str (it is
                    # sliced below and sent as message content) - confirm.
                    result = tool_registry.execute(fn_name, fn_args)
                    logger.info("工具结果: %s → %s", fn_name, result[:200])
                    payload["messages"].append({
                        "role": "tool",
                        "tool_call_id": tc["id"],
                        "content": result,
                    })
            else:
                # Every round ended in tool calls, so the last tool results
                # were never sent back to the model. Make that visible.
                logger.warning(
                    "工具调用已达到最大轮数 %d,最后一轮工具结果未再提交给模型",
                    max_rounds,
                )
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("LLM 调用失败: %s", exc)
            raise


local_llm_service = LocalLLMService()