"""
本地 CustomLLM 服务。

当前实现直接在同一个 FastAPI 进程内调用方舟 SDK,
并由 /api/chat_callback 对外提供火山要求的 SSE 回调接口。
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Iterator

from utils.env import env_float, env_int, env_str
# Ark (Volcano Engine) defaults; each value can be overridden via environment
# variables in `_load_settings`.
DEFAULT_ARK_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
DEFAULT_ARK_TIMEOUT_SECONDS = 1800  # per-request timeout, seconds (30 min)
DEFAULT_ARK_TEMPERATURE = 0  # deterministic sampling unless overridden
def _coalesce(*values):
|
||
for value in values:
|
||
if value is not None:
|
||
return value
|
||
return None
|
||
|
||
|
||
@dataclass(frozen=True)
class LocalLLMSettings:
    """Immutable configuration bundle for the local Ark-backed LLM service.

    Built once from environment variables by `_load_settings` and cached on
    `LocalLLMService`.
    """

    api_key: str  # ARK_API_KEY; required, validated at load time
    endpoint_id: str  # ARK_ENDPOINT_ID; used as the `model` in requests
    base_url: str  # Ark API root URL
    timeout_seconds: int  # timeout handed to the Ark client
    system_prompt: str  # system message prepended to every conversation
    default_temperature: float  # used when the caller supplies no temperature
def _load_settings() -> LocalLLMSettings:
    """Build a `LocalLLMSettings` from environment variables.

    Returns:
        A frozen settings object; optional fields fall back to the module
        defaults.

    Raises:
        ValueError: if ARK_API_KEY or ARK_ENDPOINT_ID is missing/empty.
    """
    api_key = env_str("ARK_API_KEY")
    endpoint_id = env_str("ARK_ENDPOINT_ID")

    # Guard clauses: both credentials are mandatory.
    if not api_key:
        raise ValueError("ARK_API_KEY 不能为空")
    if not endpoint_id:
        raise ValueError("ARK_ENDPOINT_ID 不能为空")

    return LocalLLMSettings(
        api_key=api_key,
        endpoint_id=endpoint_id,
        base_url=env_str("ARK_BASE_URL", DEFAULT_ARK_BASE_URL),
        timeout_seconds=env_int("ARK_TIMEOUT_SECONDS", DEFAULT_ARK_TIMEOUT_SECONDS),
        # New-style variable wins; legacy CUSTOM_LLM_SYSTEM_MESSAGE is the
        # fallback default. NOTE(review): if both are unset this presumably
        # yields env_str's default (possibly None despite the `str` field
        # annotation) — confirm against utils.env.
        system_prompt=env_str(
            "LOCAL_LLM_SYSTEM_PROMPT", env_str("CUSTOM_LLM_SYSTEM_MESSAGE")
        ),
        default_temperature=env_float("LOCAL_LLM_TEMPERATURE", DEFAULT_ARK_TEMPERATURE),
    )
class LocalLLMService:
    """Lazily-initialised wrapper around the Ark chat-completions API.

    Settings and the SDK client are created on first use, so importing this
    module requires neither credentials nor the Ark SDK.
    """

    def __init__(self):
        # Both caches are populated lazily; see `settings` and `_get_client`.
        self._client = None
        self._settings: LocalLLMSettings | None = None

    @property
    def settings(self) -> LocalLLMSettings:
        """Load settings from the environment on first access, then cache."""
        if self._settings is None:
            self._settings = _load_settings()
        return self._settings

    def _get_client(self):
        """Return the cached Ark client, constructing it on first call.

        Raises:
            RuntimeError: if the Ark SDK is not installed.
        """
        if self._client is None:
            try:
                from volcenginesdkarkruntime import Ark
            except ImportError as exc:
                raise RuntimeError(
                    "未安装 volcenginesdkarkruntime,请先执行 uv sync 安装依赖"
                ) from exc

            cfg = self.settings
            self._client = Ark(
                base_url=cfg.base_url,
                api_key=cfg.api_key,
                timeout=cfg.timeout_seconds,
            )
        return self._client

    def chat_stream(
        self,
        history_messages: list[dict[str, Any]],
        rag_context: str = "",
        request_options: dict[str, Any] | None = None,
    ) -> Iterator[Any]:
        """Stream chat-completion chunks from Ark for the given history.

        Args:
            history_messages: prior conversation turns (role/content dicts).
            rag_context: optional retrieved knowledge appended to the
                system prompt.
            request_options: optional per-request overrides
                (``temperature`` / ``max_tokens`` / ``top_p``).

        Yields:
            Raw streaming chunks exactly as produced by the Ark SDK.
        """
        cfg = self.settings
        client = self._get_client()
        options = request_options or {}

        # Assemble the system message: base prompt plus optional RAG context;
        # empty/None parts are dropped before joining.
        prompt_parts = [cfg.system_prompt]
        if rag_context:
            prompt_parts.append(f"### 参考知识库(绝对准则)\n{rag_context.strip()}")
        system_message = {
            "role": "system",
            "content": "\n\n".join(part for part in prompt_parts if part),
        }
        conversation = [system_message, *history_messages]

        payload: dict[str, Any] = {
            "model": cfg.endpoint_id,
            "messages": conversation,
            "temperature": _coalesce(
                options.get("temperature"),
                cfg.default_temperature,
            ),
            "stream": True,
            "stream_options": {"include_usage": True},
        }
        # Optional knobs are forwarded only when explicitly supplied.
        for knob in ("max_tokens", "top_p"):
            if options.get(knob) is not None:
                payload[knob] = options[knob]

        print(f"🚀 发起流式调用 (Endpoint: {cfg.endpoint_id})")
        try:
            for chunk in client.chat.completions.create(**payload):
                yield chunk
        except Exception as exc:
            print(f"❌ LLM 调用失败: {exc}")
            raise
# Shared module-level instance; settings and the Ark client are still created
# lazily on first use, so this import-time construction performs no I/O.
local_llm_service = LocalLLMService()