"""POST /api/chat_callback — custom LLM callback endpoint (SSE streaming response)."""
import json
import logging

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse

from schemas.chat import ChatCallbackRequest
from services.local_llm_service import local_llm_service
from services.scene_service import ensure_custom_llm_authorized, get_custom_llm_callback_settings
from utils.responses import custom_llm_error_response
router = APIRouter()

# Module-level logger so setup and streaming failures are recorded with
# tracebacks instead of being lost to stdout prints.
logger = logging.getLogger(__name__)


@router.post("/api/chat_callback")
async def chat_callback(request: Request, body: ChatCallbackRequest):
    """Relay a local-LLM chat stream back to the caller as Server-Sent Events.

    Flow: authorize the caller against the configured callback API key,
    validate the message list, open a streaming chat call on
    ``local_llm_service``, then forward each chunk as a ``data:`` SSE frame,
    always terminating the stream with ``data: [DONE]``.

    Returns:
        StreamingResponse (``text/event-stream``) on success, or a JSON error
        response from ``custom_llm_error_response`` for auth/validation/setup
        failures.
    """
    # --- Authorization ----------------------------------------------------
    try:
        settings = get_custom_llm_callback_settings()
        ensure_custom_llm_authorized(request, settings["api_key"])
    except PermissionError as exc:
        return custom_llm_error_response(
            str(exc),
            code="AuthenticationError",
            status_code=401,
        )
    except ValueError as exc:
        return custom_llm_error_response(str(exc))
    except Exception as exc:
        # Top-level boundary: log the traceback, report a generic 500 to the
        # caller (the original code swallowed the traceback entirely).
        logger.exception("chat_callback: settings/authorization setup failed")
        return custom_llm_error_response(
            f"解析请求失败: {exc}",
            code="InternalError",
            status_code=500,
        )

    # --- Request validation -----------------------------------------------
    messages = [m.model_dump() for m in body.messages]
    if not messages:
        return custom_llm_error_response(
            "messages 不能为空",
            code="BadRequest",
            status_code=400,
        )

    last_message = messages[-1]
    if last_message.get("role") != "user":
        return custom_llm_error_response(
            "最后一条消息必须是用户消息",
            code="BadRequest",
            status_code=400,
        )

    # RAG is now invoked on demand via tools; no pre-retrieval here.

    # --- Open the upstream stream ------------------------------------------
    try:
        stream_iterator = local_llm_service.chat_stream(
            history_messages=messages,
            request_options={
                "temperature": body.temperature,
                "max_tokens": body.max_tokens,
                "top_p": body.top_p,
            },
        )
    except Exception as exc:
        logger.exception("chat_callback: failed to initialize local LLM stream")
        return custom_llm_error_response(
            f"初始化本地 LLM 流式调用失败: {exc}",
            code="InternalError",
            status_code=500,
        )

    def generate_sse():
        """Yield one SSE ``data:`` frame per chunk, ending with ``[DONE]``."""
        try:
            for chunk in stream_iterator:
                if chunk is None:
                    continue
                # Pydantic chunks serialize themselves; plain dicts go
                # through json.dumps (non-ASCII kept readable).
                if hasattr(chunk, "model_dump_json"):
                    chunk_json = chunk.model_dump_json()
                else:
                    chunk_json = json.dumps(chunk, ensure_ascii=False)
                yield f"data: {chunk_json}\n\n"
        except GeneratorExit:
            # Client disconnected — propagate so the generator closes cleanly.
            raise
        except Exception:
            # Headers are already sent, so we cannot switch to an error
            # status; log with traceback and end the stream early. (The
            # original printed to stdout and dropped the traceback.)
            logger.exception(
                "/api/chat_callback streaming failed; ending SSE stream early"
            )

        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_sse(),
        status_code=200,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        },
    )