"""POST /api/chat_callback — 自定义 LLM 回调(SSE 流式响应)"""
import json
import logging

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse

from schemas.chat import ChatCallbackRequest
from services.local_llm_service import local_llm_service
from services.scene_service import ensure_custom_llm_authorized, get_custom_llm_callback_settings
from services.session_store import get_room_history
from utils.responses import custom_llm_error_response
router = APIRouter(tags=["LLM 回调"])

# Module-level logger — replaces ad-hoc print() calls so stream failures
# carry a level/timestamp and reach the configured logging sinks.
logger = logging.getLogger(__name__)


@router.post(
    "/api/chat_callback",
    summary="自定义 LLM 回调(SSE 流式)",
    description=(
        "由**火山引擎 RTC 平台**在用户发言后自动回调,返回 OpenAI 兼容格式的 SSE 流。\n\n"
        "处理逻辑:\n"
        "1. 校验 `Authorization: Bearer <CUSTOM_LLM_API_KEY>`\n"
        "2. 过滤掉 RTC 平台发送的 `欢迎语` 触发词(非真实用户输入)\n"
        "3. 若携带 `room_id` Query 参数,自动从缓存取历史并 prepend 到 messages 前\n"
        "4. 调用本地 LLM(工具调用 / RAG 按需触发),以 SSE 流返回结果\n\n"
        "**鉴权**:`Authorization: Bearer <CUSTOM_LLM_API_KEY>`"
    ),
    responses={
        401: {"description": "API Key 无效"},
        400: {"description": "messages 为空或最后一条不是 user 角色"},
        500: {"description": "LLM 初始化失败"},
    },
)
async def chat_callback(request: Request, body: ChatCallbackRequest):
    """Custom-LLM callback invoked by the RTC platform after a user speaks.

    Validates the bearer token, strips the platform's synthetic "欢迎语"
    trigger message, optionally prepends cached room history (``room_id``
    query param), then streams the local LLM's reply back as an
    OpenAI-compatible SSE stream.  Every failure path returns an
    OpenAI-style error payload via ``custom_llm_error_response``.

    Args:
        request: Incoming FastAPI request (used for auth header and
            ``room_id`` query parameter).
        body: Parsed callback payload containing ``messages`` and
            sampling options (``temperature`` / ``max_tokens`` / ``top_p``).

    Returns:
        ``StreamingResponse`` with ``text/event-stream`` frames ending in
        ``data: [DONE]``, or an error response (400 / 401 / 500).
    """
    # --- Authentication / configuration --------------------------------
    try:
        settings = get_custom_llm_callback_settings()
        ensure_custom_llm_authorized(request, settings["api_key"])
    except PermissionError as exc:
        # Bad or missing API key.
        return custom_llm_error_response(
            str(exc),
            code="AuthenticationError",
            status_code=401,
        )
    except ValueError as exc:
        # Malformed settings / request data — default error code/status.
        return custom_llm_error_response(str(exc))
    except Exception as exc:
        # Defensive catch-all so auth failures never leak a raw 500 trace.
        return custom_llm_error_response(
            f"解析请求失败: {exc}",
            code="InternalError",
            status_code=500,
        )

    # --- Message preparation --------------------------------------------
    messages = [m.model_dump() for m in body.messages]
    if not messages:
        return custom_llm_error_response(
            "messages 不能为空",
            code="BadRequest",
            status_code=400,
        )

    # Drop the RTC platform's synthetic "欢迎语" trigger — it is not real
    # user input and must not reach the LLM.
    messages = [
        m
        for m in messages
        if not (m["role"] == "user" and m["content"] == "欢迎语")
    ]

    # Prepend cached history for this room (if any) so the LLM sees the
    # full conversation context.
    room_id = request.query_params.get("room_id", "")
    if room_id:
        history = get_room_history(room_id)
        if history:
            messages = history + messages

    # Second emptiness check is NOT redundant: the welcome-word filter may
    # have removed everything, and history injection may not have added
    # anything back.
    if not messages:
        return custom_llm_error_response(
            "messages 不能为空",
            code="BadRequest",
            status_code=400,
        )

    if messages[-1].get("role") != "user":
        return custom_llm_error_response(
            "最后一条消息必须是用户消息",
            code="BadRequest",
            status_code=400,
        )

    # RAG is now a tool the LLM calls on demand — no pre-retrieval here.

    # --- LLM invocation --------------------------------------------------
    try:
        stream_iterator = local_llm_service.chat_stream(
            history_messages=messages,
            request_options={
                "temperature": body.temperature,
                "max_tokens": body.max_tokens,
                "top_p": body.top_p,
            },
        )
    except Exception as exc:
        return custom_llm_error_response(
            f"初始化本地 LLM 流式调用失败: {exc}",
            code="InternalError",
            status_code=500,
        )

    def generate_sse():
        """Yield OpenAI-compatible SSE frames, always terminating with [DONE]."""
        has_error = False
        try:
            for chunk in stream_iterator:
                if chunk is None:
                    continue
                # Pydantic chunk objects serialize themselves; plain dicts go
                # through json.dumps (ensure_ascii=False keeps CJK readable).
                if hasattr(chunk, "model_dump_json"):
                    chunk_json = chunk.model_dump_json()
                else:
                    chunk_json = json.dumps(chunk, ensure_ascii=False)
                yield f"data: {chunk_json}\n\n"
        except GeneratorExit:
            # Client disconnected — propagate so the generator closes cleanly.
            raise
        except Exception:
            # Best-effort stream: log and fall through to emit [DONE] so the
            # caller's SSE parser terminates normally instead of hanging.
            has_error = True
            logger.exception("/api/chat_callback 流式输出失败")

        if has_error:
            logger.warning("已提前结束当前 SSE 流")

        # OpenAI-compatible stream terminator.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_sse(),
        status_code=200,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # NOTE(review): wide-open CORS — confirm this is intended; the RTC
            # platform callback is server-to-server and does not need it.
            "Access-Control-Allow-Origin": "*",
        },
    )