rtc-voice-chat/backend/services/rag_service.py
2026-03-30 10:39:19 +08:00

128 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG service.

Two modes are supported:
- VOLC_KB_ENABLED=false (the default): legacy static-context logic; returns the
  contents of RAG_CONTEXT_FILE when it points to an existing file, otherwise
  RAG_STATIC_CONTEXT.
- VOLC_KB_ENABLED=true: semantic retrieval through the Volcengine knowledge
  base API.

Knowledge-base errors degrade to an empty string so they never block the main
LLM call.
"""
from __future__ import annotations
import json
from pathlib import Path
import httpx
from security.signer import Signer
from utils.env import env_bool, env_int, env_str
# Knowledge-base search path: used as the signed request path and appended to
# VOLC_KB_ENDPOINT to form the request URL.
_KB_API_PATH: str = "/api/knowledge/collection/search_knowledge"
# Service name handed to Signer when signing knowledge-base requests.
_KB_SERVICE: str = "air"
# Region included in the request data given to Signer.
_KB_REGION: str = "cn-beijing"
def _build_rag_context(result_list: list, with_attachment: bool) -> str:
    """Format knowledge-base chunks into one numbered plain-text context.

    Each chunk is rendered as a header line ``[n] <chunk_title>`` (just ``[n]``
    when the title is missing or empty), followed by its ``content``. When
    *with_attachment* is true, every entry of ``chunk_attachment`` that has a
    non-empty ``link`` adds a Markdown image line. Rendered chunks are
    separated by a blank line; numbering starts at 1.

    Args:
        result_list: Chunk mappings, as found in the API's ``result_list``.
        with_attachment: Whether to append attachment image links.

    Returns:
        The assembled context, or ``""`` when *result_list* is empty.
    """
    rendered = []
    for number, chunk in enumerate(result_list, 1):
        title = chunk.get("chunk_title") or ""
        text = chunk.get("content") or ""
        # Header, content and any image lines are joined with single newlines.
        lines = [f"[{number}] {title}" if title else f"[{number}]", f"{text}"]
        if with_attachment:
            lines.extend(
                f"![图片]({attachment['link']})"
                for attachment in chunk.get("chunk_attachment") or []
                if attachment.get("link")
            )
        rendered.append("\n".join(lines))
    return "\n\n".join(rendered)
async def _retrieve_from_kb(query: str) -> str:
    """Run a semantic search against the Volcengine knowledge base.

    All settings are read from the environment on every call:

    - ``CUSTOM_ACCESS_KEY_ID`` / ``CUSTOM_SECRET_KEY``: signing credentials
      (required).
    - ``VOLC_KB_RESOURCE_ID``, or ``VOLC_KB_NAME`` plus ``VOLC_KB_PROJECT``
      (default ``"default"``): the collection to search (required; the
      resource id wins when both are set).
    - ``VOLC_KB_ENDPOINT``: API base URL.
    - ``VOLC_KB_TOP_K`` (default 5): sent as ``limit``.
    - ``VOLC_KB_RERANK`` / ``VOLC_KB_ATTACHMENT_LINK`` (default off): optional
      ``post_processing`` switches.
    - ``VOLC_KB_TIMEOUT`` (default 30): HTTP timeout in seconds.

    Args:
        query: The text to search for.

    Returns:
        The hits rendered by ``_build_rag_context``, or ``""`` when required
        configuration is missing, the response is not a JSON object, or the
        API reports a non-zero ``code``.

    Raises:
        Transport errors from ``httpx`` (timeouts, connection failures) are not
        caught here; ``RagService.retrieve`` catches them and degrades to ``""``.
    """
    access_key = env_str("CUSTOM_ACCESS_KEY_ID")
    secret_key = env_str("CUSTOM_SECRET_KEY")
    if not access_key or not secret_key:
        print("⚠️ RAG: CUSTOM_ACCESS_KEY_ID / CUSTOM_SECRET_KEY 未配置,跳过知识库检索")
        return ""

    endpoint = env_str("VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com")
    top_k = env_int("VOLC_KB_TOP_K", 5)
    rerank = env_bool("VOLC_KB_RERANK", False)
    attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False)
    # Previously hard-coded to 30 s; configurable so latency-sensitive callers can lower it.
    timeout = env_int("VOLC_KB_TIMEOUT", 30)

    body: dict = {"query": query, "limit": top_k}
    kb_name = env_str("VOLC_KB_NAME")
    kb_resource_id = env_str("VOLC_KB_RESOURCE_ID")
    kb_project = env_str("VOLC_KB_PROJECT", "default")
    # Address the collection by resource id when available, otherwise by name + project.
    if kb_resource_id:
        body["resource_id"] = kb_resource_id
    elif kb_name:
        body["name"] = kb_name
        body["project"] = kb_project
    else:
        print("⚠️ RAG: VOLC_KB_NAME 或 VOLC_KB_RESOURCE_ID 未配置,跳过知识库检索")
        return ""

    post_processing: dict = {}
    if rerank:
        post_processing["rerank_switch"] = True
    if attachment_link:
        post_processing["get_attachment_link"] = True
    if post_processing:
        body["post_processing"] = post_processing

    headers = {"Content-Type": "application/json"}
    signer = Signer(
        request_data={
            "region": _KB_REGION,
            "method": "POST",
            "path": _KB_API_PATH,
            "headers": headers,
            "body": body,
        },
        service=_KB_SERVICE,
    )
    # NOTE(review): presumably writes the auth headers into `headers` in place,
    # since the same dict is sent below; confirm in security.signer.
    signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key})

    # The sent bytes must use exactly the serialization the signature was computed
    # over (compact, no spaces); keep this in sync with how Signer serializes `body`.
    body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
    url = endpoint.rstrip("/") + _KB_API_PATH
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, headers=headers, content=body_bytes)

    # Error pages (e.g. from a gateway) may not be JSON; report the HTTP status
    # instead of surfacing a bare JSONDecodeError.
    try:
        resp_data = resp.json()
    except ValueError:
        print(f"⚠️ RAG: 知识库返回非 JSON 响应 status={resp.status_code} body={resp.text[:200]!r}")
        return ""
    if not isinstance(resp_data, dict):
        print(f"⚠️ RAG: 知识库响应格式异常 status={resp.status_code}")
        return ""
    if resp_data.get("code") != 0:
        print(f"⚠️ RAG: 知识库检索失败 code={resp_data.get('code')} message={resp_data.get('message')}")
        return ""

    # "data" and "result_list" may be JSON null; treat that as "no hits" instead of crashing.
    data = resp_data.get("data") or {}
    result_list = data.get("result_list") or []
    print(f"✅ RAG: 检索到 {len(result_list)} 条知识片段")
    return _build_rag_context(result_list, with_attachment=attachment_link)
class RagService:
    """Provide retrieval context for a query from the knowledge base or static config.

    The backend is chosen on every call by ``VOLC_KB_ENABLED``. Retrieval and
    config-file problems degrade to ``""`` (or the next static source) rather
    than raising, so callers can always proceed.
    """

    async def retrieve(self, query: str) -> str:
        """Return context text for *query*.

        - ``VOLC_KB_ENABLED`` true: semantic search via the Volcengine
          knowledge base; any exception is logged and degrades to ``""``.
        - Otherwise: the contents of ``RAG_CONTEXT_FILE`` (``~`` expanded) when
          it names a readable UTF-8 regular file, else ``RAG_STATIC_CONTEXT``,
          else ``""``.
        """
        if env_bool("VOLC_KB_ENABLED", False):
            try:
                return await _retrieve_from_kb(query)
            except Exception as exc:  # boundary: RAG must never break the LLM call
                print(f"⚠️ RAG: 知识库检索异常,已降级: {exc}")
                return ""

        # Legacy path: static context.
        context_file = env_str("RAG_CONTEXT_FILE")
        if context_file:
            path = Path(context_file).expanduser()
            # EAFP: an is_file() check can race with deletion and says nothing about
            # permissions or encoding, so read failures fall through to the static text.
            try:
                if path.is_file():
                    return path.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as exc:
                print(f"⚠️ RAG: 读取 RAG_CONTEXT_FILE 失败,改用 RAG_STATIC_CONTEXT: {exc}")
        return env_str("RAG_STATIC_CONTEXT") or ""


# Module-level instance, created at import time.
rag_service = RagService()