"""
|
||
调试端点:POST /debug/chat, GET /debug/rag
|
||
"""
|
||
|
||
import json
import time

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from schemas.chat import DebugChatRequest
from services.local_llm_service import local_llm_service
from services.rag_service import rag_service

router = APIRouter(prefix="/debug")


@router.post("/chat")
|
||
async def debug_chat(request: DebugChatRequest):
|
||
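    """Stream a chat reply for `request.question` given `request.history`.

    The response body is the assistant's text streamed as plain text; timing,
    token usage, and a ready-to-reuse history payload are printed server-side.
    """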
    current_messages = [
        {"role": message.role, "content": message.content} for message in request.history
    ]
    current_messages.append({"role": "user", "content": request.question})

    start_time = time.time()
    stream_iterator = local_llm_service.chat_stream(
        history_messages=current_messages,
    )

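    # Generator that forwards streamed content chunks to the client while
    # accumulating the full reply and the final token-usage report.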
    def generate_text():
        full_ai_response = ""
        total_usage = None

        for chunk in stream_iterator:
            if chunk is None:
                continue

            choices = getattr(chunk, "choices", None) or []
            if choices:
                delta = getattr(choices[0], "delta", None)
                content = getattr(delta, "content", None)
                if content:
                    full_ai_response += content
                    yield content

            usage = getattr(chunk, "usage", None)
            if usage:
                total_usage = usage

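        # The stream is exhausted: log elapsed time and token usage to the console.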
print(f"DEBUG: LLM 调用耗时: {time.time() - start_time:.2f}s")
|
||
if total_usage:
|
||
print(
|
||
"🎫 Token 统计: "
|
||
f"Total={total_usage.total_tokens} "
|
||
f"(P:{total_usage.prompt_tokens}, C:{total_usage.completion_tokens})"
|
||
)
|
||
|
||
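        # Rebuild the full conversation (history + new user turn + assistant reply)
        # so it can be copied into the next debug request.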
        new_history = [
            {"role": message.role, "content": message.content}
            for message in request.history
        ]
        new_history.append({"role": "user", "content": request.question})
        new_history.append({"role": "assistant", "content": full_ai_response})

print("\n" + "=" * 50)
|
||
print("🐞 调试完成!以下是可用于下次请求的 history 结构:")
|
||
print(json.dumps({"history": new_history}, ensure_ascii=False, indent=2))
|
||
print("=" * 50 + "\n")
|
||
|
||
    return StreamingResponse(generate_text(), media_type="text/plain")


@router.get("/rag")
|
||
async def debug_rag(query: str):
|
||
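    """Run a standalone retrieval against the RAG service and return the raw context."""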
    if not query:
        return {"error": "Please provide a query parameter"}

print(f"🔍 [Debug] 正在检索知识库: {query}")
|
||
context = await rag_service.retrieve(query)
|
||
    return {
        "query": query,
        "retrieved_context": context,
        "length": len(context) if context else 0,
        "status": "success" if context else "no_results_or_error",
    }
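
# Example requests (a sketch; assumes the app runs at http://localhost:8000 and this
# router is included without an extra prefix, e.g. `app.include_router(router)`):
#
#   curl -N -X POST http://localhost:8000/debug/chat \
#     -H "Content-Type: application/json" \
#     -d '{"question": "Hello", "history": [{"role": "user", "content": "Hi"}]}'
#
#   curl "http://localhost:8000/debug/rag?query=hello"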