140 lines
4.2 KiB
Python
140 lines
4.2 KiB
Python
from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from typing import Any, Iterator

from openai import OpenAI

from utils.env import env_float, env_int, env_str
|
|
|
|
logger = logging.getLogger(__name__)

# Fallback request timeout (seconds) used when LOCAL_LLM_TIMEOUT_SECONDS is unset.
DEFAULT_LLM_TIMEOUT_SECONDS = 1800

# Fallback sampling temperature used when LOCAL_LLM_TEMPERATURE is unset.
DEFAULT_LLM_TEMPERATURE = 0
|
|
|
|
|
|
def _coalesce(*values):
|
|
for value in values:
|
|
if value is not None:
|
|
return value
|
|
return None
|
|
|
|
|
|
@dataclass(frozen=True)
class LocalLLMSettings:
    """Immutable bundle of connection/configuration values for the local LLM endpoint."""

    # Required credentials and endpoint; validated non-empty by _load_settings.
    api_key: str
    model: str
    base_url: str
    # Request timeout handed to the OpenAI client, in seconds.
    timeout_seconds: int
    # System prompt prepended to every conversation (may be empty/None — filtered at send time).
    system_prompt: str
    # Sampling temperature used when the caller supplies none.
    default_temperature: float
    # "disabled" suppresses the thinking payload; any other value is forwarded via extra_body.
    thinking_type: str
|
|
|
|
|
|
def _load_settings() -> LocalLLMSettings:
    """Build a LocalLLMSettings from environment variables.

    Raises:
        ValueError: if any of the three required variables is missing or empty.
    """
    # Required values, keyed by their environment-variable name so the
    # validation error can name the offending variable.
    required = {
        "LOCAL_LLM_API_KEY": env_str("LOCAL_LLM_API_KEY"),
        "LOCAL_LLM_MODEL": env_str("LOCAL_LLM_MODEL"),
        "LOCAL_LLM_BASE_URL": env_str("LOCAL_LLM_BASE_URL"),
    }
    for var_name, var_value in required.items():
        if not var_value:
            raise ValueError(f"{var_name} 不能为空")

    return LocalLLMSettings(
        api_key=required["LOCAL_LLM_API_KEY"],
        model=required["LOCAL_LLM_MODEL"],
        base_url=required["LOCAL_LLM_BASE_URL"],
        timeout_seconds=env_int(
            "LOCAL_LLM_TIMEOUT_SECONDS", DEFAULT_LLM_TIMEOUT_SECONDS
        ),
        system_prompt=env_str("CUSTOM_LLM_SYSTEM_MESSAGE"),
        default_temperature=env_float(
            "LOCAL_LLM_TEMPERATURE", DEFAULT_LLM_TEMPERATURE
        ),
        thinking_type=env_str("CUSTOM_LLM_THINKING_TYPE", "disabled"),
    )
|
|
|
|
|
|
class LocalLLMService:
    """Thin wrapper around an OpenAI-compatible local LLM endpoint.

    Settings and the HTTP client are created lazily on first use, so importing
    this module has no side effects and missing environment variables only
    fail when the service is actually exercised.
    """

    def __init__(self) -> None:
        # Both members are lazily populated; see `settings` and `_get_client`.
        self._client: OpenAI | None = None
        self._settings: LocalLLMSettings | None = None

    @property
    def settings(self) -> LocalLLMSettings:
        """Environment-derived settings, loaded once and cached."""
        if self._settings is None:
            self._settings = _load_settings()
        return self._settings

    def _get_client(self) -> OpenAI:
        """Return the cached OpenAI client, creating it on first call."""
        if self._client is not None:
            return self._client

        s = self.settings
        # Pass values explicitly so the settings loaded above take effect
        # (supports per-environment overrides).
        self._client = OpenAI(
            api_key=s.api_key,
            base_url=s.base_url,
            timeout=s.timeout_seconds,
        )
        return self._client

    def chat_stream(
        self,
        history_messages: list[dict[str, Any]],
        rag_context: str = "",
        request_options: dict[str, Any] | None = None,
    ) -> Iterator[Any]:
        """Stream chat-completion chunks for *history_messages*.

        Args:
            history_messages: Prior conversation turns in OpenAI message format.
            rag_context: Optional retrieved knowledge appended to the system prompt.
            request_options: Optional per-request overrides; recognized keys are
                ``temperature``, ``max_tokens``, ``top_p`` and ``reasoning_effort``.

        Yields:
            Raw streaming chunks from the endpoint; the final chunk carries
            usage statistics because ``include_usage`` is requested.

        Raises:
            Exception: any error from the underlying client is logged with its
                traceback and re-raised unchanged.
        """
        settings = self.settings
        client = self._get_client()
        request_options = request_options or {}

        system_blocks = [settings.system_prompt]
        if rag_context:
            system_blocks.append(f"### 参考知识库(绝对准则)\n{rag_context.strip()}")

        messages: list[dict[str, Any]] = [
            {
                "role": "system",
                # Falsy blocks (empty/None system prompt) are dropped from the join.
                "content": "\n\n".join(block for block in system_blocks if block),
            }
        ]
        messages.extend(history_messages)

        payload: dict[str, Any] = {
            "model": settings.model,
            "messages": messages,
            "temperature": _coalesce(
                request_options.get("temperature"),
                settings.default_temperature,
            ),
            "stream": True,
            "stream_options": {"include_usage": True},
        }

        # Optional knobs are only sent when explicitly provided.
        if request_options.get("max_tokens") is not None:
            payload["max_tokens"] = request_options["max_tokens"]
        if request_options.get("top_p") is not None:
            payload["top_p"] = request_options["top_p"]

        # Forward "thinking" only when enabled, via extra_body, so models that
        # do not support the field never see it.
        if settings.thinking_type != "disabled":
            payload["extra_body"] = {"thinking": {"type": settings.thinking_type}}
            # Callers may override the effort; default remains "low".
            payload["reasoning_effort"] = _coalesce(
                request_options.get("reasoning_effort"), "low"
            )

        logger.info(
            "发起流式调用 (Model: %s, Thinking: %s)", settings.model, settings.thinking_type
        )
        logger.debug("请求 payload:\n%s", json.dumps(payload, ensure_ascii=False, indent=2))
        try:
            stream = client.chat.completions.create(**payload)
            for chunk in stream:
                yield chunk
        except Exception as exc:
            # logger.exception records the traceback, not just str(exc).
            logger.exception("LLM 调用失败: %s", exc)
            raise
|
|
|
|
|
|
# Module-level singleton: importers share one lazily-initialized client/settings.
# Construction is side-effect free (no env reads, no network) until first use.
local_llm_service = LocalLLMService()
|