rtc-voice-chat/backend/services/local_llm_service.py
2026-04-02 09:40:23 +08:00

221 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations

import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Iterator

from openai import OpenAI

from config.prompt_config import load_prompt
from tools import tool_registry
from utils.env import env_float, env_int, env_str
logger = logging.getLogger(__name__)
def _accumulate_tool_calls(buffer: dict[int, dict], delta_tool_calls: list) -> None:
    """Merge streamed tool_call delta fragments into ``buffer``, keyed by index.

    Each fragment carries an ``index`` naming the tool call it belongs to.
    A non-empty ``id`` overwrites the stored id, while ``function.name`` and
    ``function.arguments`` fragments are concatenated in arrival order.
    """
    for fragment in delta_tool_calls:
        slot = buffer.setdefault(
            fragment.index,
            {"id": fragment.id or "", "name": "", "arguments": ""},
        )
        if fragment.id:
            slot["id"] = fragment.id
        call_fn = fragment.function
        if call_fn:
            slot["name"] += call_fn.name or ""
            slot["arguments"] += call_fn.arguments or ""
# Fallback request timeout (seconds) when LOCAL_LLM_TIMEOUT_SECONDS is unset.
DEFAULT_LLM_TIMEOUT_SECONDS = 1800
# Fallback sampling temperature when LOCAL_LLM_TEMPERATURE is unset.
DEFAULT_LLM_TEMPERATURE = 0
def _coalesce(*values):
    """Return the first argument that is not ``None``, or ``None`` if all are.

    Falsy-but-present values (``0``, ``""``, ``False``) are returned as-is.
    """
    return next((value for value in values if value is not None), None)
@dataclass(frozen=True)
class LocalLLMSettings:
    """Immutable connection and request settings for the local LLM endpoint.

    Frozen so a loaded configuration cannot be mutated after creation.
    """

    # Credentials and routing for the OpenAI-compatible endpoint.
    api_key: str
    model: str
    base_url: str
    # Client-side request timeout, in seconds.
    timeout_seconds: int
    # Sampling temperature applied when the caller does not override it.
    default_temperature: float
    # Provider-specific "thinking" mode switch (e.g. "disabled").
    thinking_type: str
def _load_settings() -> LocalLLMSettings:
    """Build a :class:`LocalLLMSettings` from environment variables.

    Raises:
        ValueError: if any of the three required variables is missing or empty.
    """
    required = {
        "LOCAL_LLM_API_KEY": env_str("LOCAL_LLM_API_KEY"),
        "LOCAL_LLM_MODEL": env_str("LOCAL_LLM_MODEL"),
        "LOCAL_LLM_BASE_URL": env_str("LOCAL_LLM_BASE_URL"),
    }
    # Validate in declaration order so the first missing variable is reported.
    for var_name, var_value in required.items():
        if not var_value:
            raise ValueError(f"{var_name} 不能为空")
    return LocalLLMSettings(
        api_key=required["LOCAL_LLM_API_KEY"],
        model=required["LOCAL_LLM_MODEL"],
        base_url=required["LOCAL_LLM_BASE_URL"],
        timeout_seconds=env_int(
            "LOCAL_LLM_TIMEOUT_SECONDS", DEFAULT_LLM_TIMEOUT_SECONDS
        ),
        default_temperature=env_float(
            "LOCAL_LLM_TEMPERATURE", DEFAULT_LLM_TEMPERATURE
        ),
        # NOTE(review): prefix differs from the LOCAL_LLM_* family used above —
        # confirm CUSTOM_LLM_THINKING_TYPE is the intended variable name.
        thinking_type=env_str("CUSTOM_LLM_THINKING_TYPE", "disabled"),
    )
class LocalLLMService:
    """Streaming chat client for a local OpenAI-compatible LLM endpoint.

    Settings and the SDK client are created lazily on first use, so importing
    this module has no side effects. When tools are enabled via environment,
    ``chat_stream`` runs up to ``TOOLS_MAX_ROUNDS`` rounds of tool execution
    before streaming the final answer to the caller.
    """

    def __init__(self):
        self._client = None
        self._settings: LocalLLMSettings | None = None

    @property
    def settings(self) -> LocalLLMSettings:
        """Environment-derived settings, loaded once on first access."""
        if self._settings is None:
            self._settings = _load_settings()
        return self._settings

    def _get_client(self):
        """Return the cached OpenAI client, creating it on first call."""
        if self._client is None:
            s = self.settings
            # Pass values explicitly so the client always reflects the
            # settings loaded above (supports per-environment overrides).
            self._client = OpenAI(
                api_key=s.api_key,
                base_url=s.base_url,
                timeout=s.timeout_seconds,
            )
        return self._client

    def _build_payload(
        self,
        settings: LocalLLMSettings,
        history_messages: list[dict[str, Any]],
        request_options: dict[str, Any],
        prompt_name: str,
    ) -> dict[str, Any]:
        """Assemble the ``chat.completions`` request payload."""
        messages: list[dict[str, Any]] = [
            {
                "role": "system",
                "content": load_prompt(prompt_name),
            }
        ]
        # Drop any system message already present in history so the request
        # carries exactly one system prompt.
        messages.extend(
            msg for msg in history_messages if msg.get("role") != "system"
        )
        payload: dict[str, Any] = {
            "model": settings.model,
            "messages": messages,
            "temperature": _coalesce(
                request_options.get("temperature"),
                settings.default_temperature,
            ),
            "stream": True,
            "stream_options": {"include_usage": True},
        }
        # Optional sampling overrides are forwarded only when explicitly set.
        for key in ("max_tokens", "top_p"):
            if request_options.get(key) is not None:
                payload[key] = request_options[key]
        # Provider-specific "thinking" switch travels via extra_body.
        payload["extra_body"] = {"thinking": {"type": settings.thinking_type}}
        if settings.thinking_type != "disabled":
            payload["reasoning_effort"] = "medium"  # allowed: low / medium / high
        return payload

    def _run_tools(self, payload: dict[str, Any], tc_buffer: dict[int, dict]) -> None:
        """Append the assistant tool_calls message, execute each tool via
        ``tool_registry``, and append the tool results to the conversation."""
        # Sort by stream index so the replayed order matches the model's order.
        tool_calls_for_msg = [
            {
                "id": tc["id"],
                "type": "function",
                "function": {"name": tc["name"], "arguments": tc["arguments"]},
            }
            for _, tc in sorted(tc_buffer.items())
        ]
        payload["messages"].append({
            "role": "assistant",
            "content": None,
            "tool_calls": tool_calls_for_msg,
        })
        for tc in tool_calls_for_msg:
            fn_name = tc["function"]["name"]
            fn_args = tc["function"]["arguments"]
            logger.info("执行工具: %s(%s)", fn_name, fn_args)
            result = tool_registry.execute(fn_name, fn_args)
            # FIX: the original format string "工具结果: %s%s" ran the tool
            # name and the (truncated) result together with no separator.
            logger.info("工具结果: %s -> %s", fn_name, result[:200])
            payload["messages"].append({
                "role": "tool",
                "tool_call_id": tc["id"],
                "content": result,
            })

    def chat_stream(
        self,
        history_messages: list[dict[str, Any]],
        request_options: dict[str, Any] | None = None,
        prompt_name: str = "default",
    ) -> Iterator[Any]:
        """Yield streaming chat-completion chunks for ``history_messages``.

        Args:
            history_messages: Prior conversation messages (any existing
                system message is filtered out and replaced).
            request_options: Optional ``temperature`` / ``max_tokens`` /
                ``top_p`` overrides.
            prompt_name: Name of the system prompt loaded via ``load_prompt``.

        Yields:
            Raw SDK stream chunks for the final textual answer (tool-call
            traffic is consumed internally, never surfaced to the caller).

        Raises:
            Whatever the underlying SDK raises on request failure.
        """
        settings = self.settings
        client = self._get_client()
        request_options = request_options or {}
        payload = self._build_payload(
            settings, history_messages, request_options, prompt_name
        )

        # Tool use is opt-in via environment.
        tools_enabled = os.getenv("TOOLS_ENABLED", "false").lower() == "true"
        tools_list = tool_registry.get_openai_tools() if tools_enabled else []
        if tools_list:
            payload["tools"] = tools_list
        max_rounds = int(os.getenv("TOOLS_MAX_ROUNDS", "3"))

        logger.info("发起流式调用 (Model: %s, Thinking: %s)", settings.model, settings.thinking_type)
        logger.debug("请求 payload:\n%s", json.dumps(payload, ensure_ascii=False, indent=2))
        try:
            for round_idx in range(max_rounds):
                logger.info("LLM 调用第 %d 轮 (tools=%d)", round_idx + 1, len(tools_list))
                tc_buffer: dict[int, dict] = {}
                has_tool_calls = False
                stream = client.chat.completions.create(**payload)
                for chunk in stream:
                    if not chunk.choices:
                        # Usage-only chunk: pass through unless this round is
                        # tool traffic the caller must not see.
                        if not has_tool_calls:
                            yield chunk
                        continue
                    delta = chunk.choices[0].delta
                    if delta and delta.tool_calls:
                        # Accumulate tool_call fragments internally.
                        has_tool_calls = True
                        _accumulate_tool_calls(tc_buffer, delta.tool_calls)
                        continue
                    # Plain text content streams straight to the caller.
                    if not has_tool_calls:
                        yield chunk
                if not tc_buffer:
                    break  # No tool calls this round: final answer streamed.
                # Execute tools, append results, then run the next round.
                self._run_tools(payload, tc_buffer)
            else:
                # Every round requested tools; we ran out without a final answer.
                logger.warning("已达到最大工具调用轮数 (%d),提前结束", max_rounds)
        except Exception as exc:
            # logger.exception records the traceback (logger.error did not).
            logger.exception("LLM 调用失败: %s", exc)
            raise
local_llm_service = LocalLLMService()