rtc-voice-chat/backend/services/local_llm_service.py

140 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
本地 CustomLLM 服务。
当前实现直接在同一个 FastAPI 进程内调用方舟 SDK
并由 /api/chat_callback 对外提供火山要求的 SSE 回调接口。
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Iterator
from utils.env import env_float, env_int, env_str
# Default base URL of the Volcengine Ark OpenAI-compatible API (cn-beijing region).
DEFAULT_ARK_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
# Default request timeout handed to the Ark client, in seconds (30 minutes).
DEFAULT_ARK_TIMEOUT_SECONDS = 1800
# Default sampling temperature when the caller does not override it (deterministic output).
DEFAULT_ARK_TEMPERATURE = 0
def _coalesce(*values):
for value in values:
if value is not None:
return value
return None
@dataclass(frozen=True)
class LocalLLMSettings:
    """Immutable configuration for the local Ark-backed LLM service.

    Instances are built once from environment variables by ``_load_settings``
    and cached on ``LocalLLMService``.
    """

    # Ark API key (from ARK_API_KEY; required, validated by _load_settings).
    api_key: str
    # Ark endpoint/model id (from ARK_ENDPOINT_ID; required).
    endpoint_id: str
    # Base URL of the Ark OpenAI-compatible API.
    base_url: str
    # Request timeout passed to the Ark client, in seconds.
    timeout_seconds: int
    # System prompt prepended to every conversation (may be empty/None —
    # presumably env_str returns None when unset; TODO confirm in utils.env).
    system_prompt: str
    # Sampling temperature used when a request does not specify one.
    default_temperature: float
def _load_settings() -> LocalLLMSettings:
    """Build a :class:`LocalLLMSettings` from environment variables.

    Raises:
        ValueError: if ``ARK_API_KEY`` or ``ARK_ENDPOINT_ID`` is missing/empty.
    """
    # Read both required variables up front, then validate.
    api_key = env_str("ARK_API_KEY")
    endpoint_id = env_str("ARK_ENDPOINT_ID")
    if not api_key:
        raise ValueError("ARK_API_KEY 不能为空")
    if not endpoint_id:
        raise ValueError("ARK_ENDPOINT_ID 不能为空")

    # LOCAL_LLM_SYSTEM_PROMPT wins; CUSTOM_LLM_SYSTEM_MESSAGE is the fallback.
    system_prompt = env_str(
        "LOCAL_LLM_SYSTEM_PROMPT",
        env_str("CUSTOM_LLM_SYSTEM_MESSAGE"),
    )

    return LocalLLMSettings(
        api_key=api_key,
        endpoint_id=endpoint_id,
        base_url=env_str("ARK_BASE_URL", DEFAULT_ARK_BASE_URL),
        timeout_seconds=env_int("ARK_TIMEOUT_SECONDS", DEFAULT_ARK_TIMEOUT_SECONDS),
        system_prompt=system_prompt,
        default_temperature=env_float("LOCAL_LLM_TEMPERATURE", DEFAULT_ARK_TEMPERATURE),
    )
class LocalLLMService:
    """Lazily-initialised wrapper around the Volcengine Ark chat-completions API.

    Both the settings and the SDK client are created on first use, so merely
    importing this module does not require the environment to be configured
    or the Ark SDK to be installed.
    """

    def __init__(self):
        # Both are populated lazily; see `settings` and `_get_client`.
        self._client = None
        self._settings: LocalLLMSettings | None = None

    @property
    def settings(self) -> LocalLLMSettings:
        """Load settings from the environment on first access, then cache them."""
        if self._settings is None:
            self._settings = _load_settings()
        return self._settings

    def _get_client(self):
        """Create the Ark SDK client on first call and return the cached instance.

        Raises:
            RuntimeError: if the ``volcenginesdkarkruntime`` package is not installed.
        """
        if self._client is None:
            try:
                from volcenginesdkarkruntime import Ark
            except ImportError as exc:
                raise RuntimeError(
                    "未安装 volcenginesdkarkruntime请先执行 uv sync 安装依赖"
                ) from exc
            cfg = self.settings
            self._client = Ark(
                base_url=cfg.base_url,
                api_key=cfg.api_key,
                timeout=cfg.timeout_seconds,
            )
        return self._client

    def chat_stream(
        self,
        history_messages: list[dict[str, Any]],
        rag_context: str = "",
        request_options: dict[str, Any] | None = None,
    ) -> Iterator[Any]:
        """Yield streaming chat-completion chunks from the Ark endpoint.

        Args:
            history_messages: prior conversation turns in OpenAI message format.
            rag_context: optional retrieved knowledge appended to the system prompt.
            request_options: optional per-request overrides
                (``temperature``, ``max_tokens``, ``top_p``).

        Yields:
            Raw chunk objects produced by the Ark streaming SDK.
        """
        cfg = self.settings
        client = self._get_client()
        opts = request_options or {}

        # Assemble the system message: base prompt plus optional RAG context.
        blocks = [cfg.system_prompt]
        if rag_context:
            blocks.append(f"### 参考知识库(绝对准则)\n{rag_context.strip()}")
        system_message = {
            "role": "system",
            "content": "\n\n".join(block for block in blocks if block),
        }
        messages = [system_message, *history_messages]

        payload: dict[str, Any] = {
            "model": cfg.endpoint_id,
            "messages": messages,
            "temperature": _coalesce(
                opts.get("temperature"),
                cfg.default_temperature,
            ),
            "stream": True,
            "stream_options": {"include_usage": True},
        }
        # Only forward optional knobs the caller explicitly set.
        for key in ("max_tokens", "top_p"):
            if opts.get(key) is not None:
                payload[key] = opts[key]

        print(f"🚀 发起流式调用 (Endpoint: {cfg.endpoint_id})")
        try:
            yield from client.chat.completions.create(**payload)
        except Exception as exc:
            print(f"❌ LLM 调用失败: {exc}")
            raise
local_llm_service = LocalLLMService()