221 lines
7.2 KiB
Python
from __future__ import annotations

import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Iterator

from openai import OpenAI

from config.prompt_config import load_prompt
from tools import tool_registry
from utils.env import env_float, env_int, env_str
# Module-level logger, named after this module per the standard
# `logging.getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
def _accumulate_tool_calls(buffer: dict[int, dict], delta_tool_calls: list) -> None:
|
||
"""将流式增量 tool_calls 合并到 buffer[index] 中。"""
|
||
for tc_delta in delta_tool_calls:
|
||
idx = tc_delta.index
|
||
if idx not in buffer:
|
||
buffer[idx] = {"id": tc_delta.id or "", "name": "", "arguments": ""}
|
||
if tc_delta.id:
|
||
buffer[idx]["id"] = tc_delta.id
|
||
fn = tc_delta.function
|
||
if fn:
|
||
if fn.name:
|
||
buffer[idx]["name"] += fn.name
|
||
if fn.arguments:
|
||
buffer[idx]["arguments"] += fn.arguments
|
||
|
||
# Fallbacks used when the corresponding LOCAL_LLM_* env vars are unset.
DEFAULT_LLM_TIMEOUT_SECONDS = 1800  # 30 minutes
DEFAULT_LLM_TEMPERATURE = 0
def _coalesce(*values):
|
||
for value in values:
|
||
if value is not None:
|
||
return value
|
||
return None
|
||
|
||
|
||
@dataclass(frozen=True)
class LocalLLMSettings:
    """Immutable connection and generation settings for the local LLM endpoint."""

    api_key: str
    model: str
    base_url: str
    timeout_seconds: int        # passed as the OpenAI client HTTP timeout
    default_temperature: float  # used when a request does not specify one
    thinking_type: str          # "disabled" or a provider-specific thinking mode
def _load_settings() -> LocalLLMSettings:
    """Build a :class:`LocalLLMSettings` from environment variables.

    Raises:
        ValueError: if any of the required LOCAL_LLM_* variables is empty.
    """
    required = {
        "LOCAL_LLM_API_KEY": env_str("LOCAL_LLM_API_KEY"),
        "LOCAL_LLM_MODEL": env_str("LOCAL_LLM_MODEL"),
        "LOCAL_LLM_BASE_URL": env_str("LOCAL_LLM_BASE_URL"),
    }
    # Validate in declaration order so the first missing variable is reported.
    for name, value in required.items():
        if not value:
            raise ValueError(f"{name} 不能为空")

    return LocalLLMSettings(
        api_key=required["LOCAL_LLM_API_KEY"],
        model=required["LOCAL_LLM_MODEL"],
        base_url=required["LOCAL_LLM_BASE_URL"],
        timeout_seconds=env_int(
            "LOCAL_LLM_TIMEOUT_SECONDS", DEFAULT_LLM_TIMEOUT_SECONDS
        ),
        default_temperature=env_float(
            "LOCAL_LLM_TEMPERATURE", DEFAULT_LLM_TEMPERATURE
        ),
        # NOTE(review): prefix differs from the other vars (CUSTOM_ vs
        # LOCAL_) — confirm this is intentional before renaming.
        thinking_type=env_str("CUSTOM_LLM_THINKING_TYPE", "disabled"),
    )
class LocalLLMService:
    """Streaming chat service backed by an OpenAI-compatible local endpoint.

    Settings and the HTTP client are created lazily, so importing this module
    never fails merely because environment variables are unset.
    """

    def __init__(self):
        # Both are created on first use; see `settings` and `_get_client`.
        self._client = None
        self._settings: LocalLLMSettings | None = None

    @property
    def settings(self) -> LocalLLMSettings:
        """Environment-derived settings, loaded once on first access."""
        if self._settings is None:
            self._settings = _load_settings()
        return self._settings

    def _get_client(self):
        """Create (once) and return the OpenAI client for the local endpoint."""
        if self._client is not None:
            return self._client

        s = self.settings
        # Passed explicitly so the values loaded into settings are used
        # (supports multi-environment overrides).
        self._client = OpenAI(
            api_key=s.api_key,
            base_url=s.base_url,
            timeout=s.timeout_seconds,
        )
        return self._client

    def _build_payload(
        self,
        settings: LocalLLMSettings,
        history_messages: list[dict[str, Any]],
        request_options: dict[str, Any],
        prompt_name: str,
    ) -> dict[str, Any]:
        """Assemble the chat.completions request payload for one conversation."""
        messages: list[dict[str, Any]] = [
            {
                "role": "system",
                "content": load_prompt(prompt_name),
            }
        ]
        # Filter out system messages already present in history to avoid
        # duplicating the system prompt.
        messages.extend(
            msg for msg in history_messages if msg.get("role") != "system"
        )

        payload: dict[str, Any] = {
            "model": settings.model,
            "messages": messages,
            # Request-level temperature wins; fall back to the configured default.
            "temperature": _coalesce(
                request_options.get("temperature"),
                settings.default_temperature,
            ),
            "stream": True,
            "stream_options": {"include_usage": True},
        }

        if request_options.get("max_tokens") is not None:
            payload["max_tokens"] = request_options["max_tokens"]
        if request_options.get("top_p") is not None:
            payload["top_p"] = request_options["top_p"]

        # The thinking configuration is passed via extra_body.
        payload["extra_body"] = {"thinking": {"type": settings.thinking_type}}
        if settings.thinking_type != "disabled":
            payload["reasoning_effort"] = "medium"  # allowed: low / medium / high

        return payload

    @staticmethod
    def _run_tools(payload: dict[str, Any], tc_buffer: dict[int, dict]) -> None:
        """Append the assistant tool_calls message, execute each tool, and
        append the tool results to ``payload["messages"]``."""
        # Reassemble the buffered fragments in index order.
        tool_calls_for_msg = [
            {
                "id": tc["id"],
                "type": "function",
                "function": {"name": tc["name"], "arguments": tc["arguments"]},
            }
            for _, tc in sorted(tc_buffer.items())
        ]
        payload["messages"].append({
            "role": "assistant",
            "content": None,
            "tool_calls": tool_calls_for_msg,
        })

        for tc in tool_calls_for_msg:
            fn_name = tc["function"]["name"]
            fn_args = tc["function"]["arguments"]
            logger.info("执行工具: %s(%s)", fn_name, fn_args)
            result = tool_registry.execute(fn_name, fn_args)
            logger.info("工具结果: %s → %s", fn_name, result[:200])
            payload["messages"].append({
                "role": "tool",
                "tool_call_id": tc["id"],
                "content": result,
            })

    def chat_stream(
        self,
        history_messages: list[dict[str, Any]],
        request_options: dict[str, Any] | None = None,
        prompt_name: str = "default",
    ) -> Iterator[Any]:
        """Stream chat-completion chunks, transparently running tool calls.

        Tool-call rounds (up to TOOLS_MAX_ROUNDS) are handled internally:
        tool_call deltas are buffered instead of being yielded, tools are
        executed, and the augmented conversation is re-sent to the model.

        Args:
            history_messages: prior conversation turns (any existing system
                messages are dropped in favour of the configured prompt).
            request_options: optional overrides (``temperature``,
                ``max_tokens``, ``top_p``).
            prompt_name: name of the system prompt to load.

        Yields:
            Raw chunk objects from the OpenAI SDK stream.

        Raises:
            Exception: re-raises any error from the underlying API call
                after logging it.
        """
        settings = self.settings
        client = self._get_client()
        request_options = request_options or {}

        payload = self._build_payload(
            settings, history_messages, request_options, prompt_name
        )

        # Inject tools when enabled.
        tools_enabled = os.getenv("TOOLS_ENABLED", "false").lower() == "true"
        tools_list = tool_registry.get_openai_tools() if tools_enabled else []
        if tools_list:
            payload["tools"] = tools_list

        max_rounds = int(os.getenv("TOOLS_MAX_ROUNDS", "3"))

        logger.info("发起流式调用 (Model: %s, Thinking: %s)", settings.model, settings.thinking_type)
        logger.debug("请求 payload:\n%s", json.dumps(payload, ensure_ascii=False, indent=2))

        pending_tools = False
        try:
            for round_idx in range(max_rounds):
                logger.info("LLM 调用第 %d 轮 (tools=%d)", round_idx + 1, len(tools_list))

                tc_buffer: dict[int, dict] = {}
                has_tool_calls = False

                stream = client.chat.completions.create(**payload)

                for chunk in stream:
                    if not chunk.choices:
                        # Choice-less chunks (e.g. usage); suppressed during
                        # tool rounds, forwarded otherwise.
                        if not has_tool_calls:
                            yield chunk
                        continue

                    delta = chunk.choices[0].delta

                    # Buffer tool_call fragments instead of yielding them.
                    if delta and delta.tool_calls:
                        has_tool_calls = True
                        _accumulate_tool_calls(tc_buffer, delta.tool_calls)
                        continue

                    # Plain text content is forwarded directly.
                    if not has_tool_calls:
                        yield chunk

                if not tc_buffer:
                    pending_tools = False
                    break  # no tool calls: the final answer was streamed

                self._run_tools(payload, tc_buffer)
                pending_tools = True
                # Continue with the next LLM round using the tool results.

            # FIX: previously, exhausting max_rounds right after executing
            # tools silently ended the stream without a final model response.
            if pending_tools:
                logger.warning(
                    "TOOLS_MAX_ROUNDS (%d) exhausted with pending tool results; "
                    "stream ended without a final response",
                    max_rounds,
                )
        except Exception as exc:
            # FIX: logger.exception records the traceback; logger.error did not.
            logger.exception("LLM 调用失败: %s", exc)
            raise
# Shared module-level instance; initialization is lazy, so constructing it
# here performs no environment reads or network I/O at import time.
local_llm_service = LocalLLMService()
|