109 lines
3.2 KiB
Python
109 lines
3.2 KiB
Python
"""
POST /api/chat_callback — 自定义 LLM 回调(SSE 流式响应)
"""

import json
import logging

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse

from services.local_llm_service import local_llm_service
from services.scene_service import ensure_custom_llm_authorized, get_custom_llm_callback_settings
from utils.responses import custom_llm_error_response
||
|
||
router = APIRouter()
|
||
|
||
|
||
@router.post("/api/chat_callback")
|
||
async def chat_callback(request: Request):
|
||
try:
|
||
settings = get_custom_llm_callback_settings()
|
||
ensure_custom_llm_authorized(request, settings["api_key"])
|
||
payload = await request.json()
|
||
except PermissionError as exc:
|
||
return custom_llm_error_response(
|
||
str(exc),
|
||
code="AuthenticationError",
|
||
status_code=401,
|
||
)
|
||
except json.JSONDecodeError:
|
||
return custom_llm_error_response(
|
||
"请求体必须是合法的 JSON",
|
||
code="BadRequest",
|
||
status_code=400,
|
||
)
|
||
except ValueError as exc:
|
||
return custom_llm_error_response(str(exc))
|
||
except Exception as exc:
|
||
return custom_llm_error_response(
|
||
f"解析请求失败: {exc}",
|
||
code="InternalError",
|
||
status_code=500,
|
||
)
|
||
|
||
messages = payload.get("messages")
|
||
if not isinstance(messages, list) or not messages:
|
||
return custom_llm_error_response(
|
||
"messages 不能为空",
|
||
code="BadRequest",
|
||
status_code=400,
|
||
)
|
||
|
||
last_message = messages[-1]
|
||
if last_message.get("role") != "user":
|
||
return custom_llm_error_response(
|
||
"最后一条消息必须是用户消息",
|
||
code="BadRequest",
|
||
status_code=400,
|
||
)
|
||
|
||
try:
|
||
stream_iterator = local_llm_service.chat_stream(
|
||
history_messages=messages,
|
||
request_options={
|
||
"temperature": payload.get("temperature"),
|
||
"max_tokens": payload.get("max_tokens"),
|
||
"top_p": payload.get("top_p"),
|
||
},
|
||
)
|
||
except Exception as exc:
|
||
return custom_llm_error_response(
|
||
f"初始化本地 LLM 流式调用失败: {exc}",
|
||
code="InternalError",
|
||
status_code=500,
|
||
)
|
||
|
||
def generate_sse():
|
||
has_error = False
|
||
try:
|
||
for chunk in stream_iterator:
|
||
if chunk is None:
|
||
continue
|
||
|
||
if hasattr(chunk, "model_dump_json"):
|
||
chunk_json = chunk.model_dump_json()
|
||
else:
|
||
chunk_json = json.dumps(chunk, ensure_ascii=False)
|
||
yield f"data: {chunk_json}\n\n"
|
||
except GeneratorExit:
|
||
raise
|
||
except Exception as exc:
|
||
has_error = True
|
||
print(f"❌ /api/chat_callback 流式输出失败: {exc}")
|
||
|
||
if has_error:
|
||
print("⚠️ 已提前结束当前 SSE 流")
|
||
|
||
yield "data: [DONE]\n\n"
|
||
|
||
return StreamingResponse(
|
||
generate_sse(),
|
||
status_code=200,
|
||
media_type="text/event-stream",
|
||
headers={
|
||
"Cache-Control": "no-cache",
|
||
"Connection": "keep-alive",
|
||
"Access-Control-Allow-Origin": "*",
|
||
},
|
||
)
|