128 lines
4.3 KiB
Python
128 lines
4.3 KiB
Python
"""
RAG 服务。

支持两种模式:
- VOLC_KB_ENABLED=false(默认):走旧的静态上下文逻辑——优先读取 RAG_CONTEXT_FILE 指向的文件,文件不可用时使用 RAG_STATIC_CONTEXT
- VOLC_KB_ENABLED=true:调用火山引擎知识库 API 进行语义检索

知识库检索出错时自动降级为空字符串,不阻断主链路 LLM 调用。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
|
||
from security.signer import Signer
|
||
from utils.env import env_bool, env_int, env_str
|
||
|
||
# Request path of the Volcengine knowledge-base search API. It is appended to
# the endpoint URL and is also passed to the request signer.
_KB_API_PATH: str = "/api/knowledge/collection/search_knowledge"
# Service name passed to the request signer.
_KB_SERVICE: str = "air"
# Signing region for knowledge-base requests. It must match the region of the
# endpoint actually used (the default VOLC_KB_ENDPOINT is in cn-beijing).
_KB_REGION: str = "cn-beijing"
|
||
|
||
|
||
def _build_rag_context(result_list: list, with_attachment: bool) -> str:
|
||
parts = []
|
||
for i, chunk in enumerate(result_list, start=1):
|
||
title = chunk.get("chunk_title") or ""
|
||
content = chunk.get("content") or ""
|
||
header = f"[{i}] {title}" if title else f"[{i}]"
|
||
block = f"{header}\n{content}"
|
||
|
||
if with_attachment:
|
||
attachments = chunk.get("chunk_attachment") or []
|
||
image_links = [a["link"] for a in attachments if a.get("link")]
|
||
if image_links:
|
||
block += "\n" + "\n".join(f"" for link in image_links)
|
||
|
||
parts.append(block)
|
||
return "\n\n".join(parts)
|
||
|
||
|
||
def _build_kb_search_body(query: str, top_k: int, rerank: bool, attachment_link: bool) -> dict | None:
    """Build the JSON body for the knowledge-base ``search_knowledge`` request.

    The target collection is chosen by ``VOLC_KB_RESOURCE_ID`` when set.
    Otherwise ``VOLC_KB_NAME`` is used together with ``VOLC_KB_PROJECT``
    (default ``"default"``).

    Returns:
        The request body dict, or ``None`` (after printing a warning) when
        neither ``VOLC_KB_RESOURCE_ID`` nor ``VOLC_KB_NAME`` is configured.
    """
    body: dict = {"query": query, "limit": top_k}

    kb_resource_id = env_str("VOLC_KB_RESOURCE_ID")
    kb_name = env_str("VOLC_KB_NAME")
    if kb_resource_id:
        # resource_id takes precedence; name/project are not sent with it.
        body["resource_id"] = kb_resource_id
    elif kb_name:
        body["name"] = kb_name
        body["project"] = env_str("VOLC_KB_PROJECT", "default")
    else:
        print("⚠️ RAG: VOLC_KB_NAME 或 VOLC_KB_RESOURCE_ID 未配置,跳过知识库检索")
        return None

    post_processing: dict = {}
    if rerank:
        post_processing["rerank_switch"] = True
    if attachment_link:
        post_processing["get_attachment_link"] = True
    if post_processing:
        body["post_processing"] = post_processing
    return body


async def _retrieve_from_kb(query: str) -> str:
    """Search the Volcengine knowledge base for *query* and format the hits.

    Configuration comes from environment variables:
    ``CUSTOM_ACCESS_KEY_ID`` / ``CUSTOM_SECRET_KEY`` (required),
    ``VOLC_KB_ENDPOINT``, ``VOLC_KB_REGION`` (signing region, defaults to
    ``_KB_REGION``), ``VOLC_KB_TOP_K`` (default 5), ``VOLC_KB_RERANK``,
    ``VOLC_KB_ATTACHMENT_LINK``, ``VOLC_KB_TIMEOUT`` (seconds, default 30)
    and the collection selectors read by ``_build_kb_search_body``.

    Returns:
        Context text built by ``_build_rag_context``. Returns ``""`` (with a
        printed warning) when credentials or the target collection are not
        configured, when the response body is not a JSON object, or when the
        API reports a non-zero ``code``.

    Raises:
        Transport errors from httpx (connection failures, timeouts) propagate.
        ``RagService.retrieve`` catches them.
    """
    access_key = env_str("CUSTOM_ACCESS_KEY_ID")
    secret_key = env_str("CUSTOM_SECRET_KEY")
    if not access_key or not secret_key:
        print("⚠️ RAG: CUSTOM_ACCESS_KEY_ID / CUSTOM_SECRET_KEY 未配置,跳过知识库检索")
        return ""

    endpoint = env_str("VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com")
    # The signing region must match the endpoint's region. It can be overridden
    # together with VOLC_KB_ENDPOINT instead of always signing for cn-beijing.
    region = env_str("VOLC_KB_REGION", _KB_REGION) or _KB_REGION
    top_k = env_int("VOLC_KB_TOP_K", 5)
    rerank = env_bool("VOLC_KB_RERANK", False)
    attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False)
    timeout = env_int("VOLC_KB_TIMEOUT", 30)

    body = _build_kb_search_body(query, top_k, rerank, attachment_link)
    if body is None:
        return ""

    headers = {"Content-Type": "application/json"}
    signer = Signer(
        request_data={
            "region": region,
            "method": "POST",
            "path": _KB_API_PATH,
            "headers": headers,
            "body": body,
        },
        service=_KB_SERVICE,
    )
    # NOTE(review): the request below is sent with this same `headers` dict, so
    # this relies on add_authorization() writing the auth headers into it in
    # place. Confirm against security.signer.
    signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key})

    # Signing and sending must use byte-identical serialization (compact, no spaces).
    # NOTE(review): assumes Signer serializes `body` with the same separators and
    # ensure_ascii=False. Confirm against security.signer.
    body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
    url = endpoint.rstrip("/") + _KB_API_PATH

    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, headers=headers, content=body_bytes)

    try:
        resp_data = resp.json()
    except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
        resp_data = None
    if not isinstance(resp_data, dict):
        print(f"⚠️ RAG: 知识库返回了无法解析的响应 status={resp.status_code}")
        return ""

    if resp_data.get("code") != 0:
        print(f"⚠️ RAG: 知识库检索失败 code={resp_data.get('code')} message={resp_data.get('message')}")
        return ""

    # Tolerate explicit nulls for "data"/"result_list". `.get(key, default)`
    # alone would return None there and crash on `.get` / `len`.
    result_list = (resp_data.get("data") or {}).get("result_list") or []
    print(f"✅ RAG: 检索到 {len(result_list)} 条知识片段")
    return _build_rag_context(result_list, with_attachment=attachment_link)
|
||
|
||
|
||
class RagService:
    """Provide RAG context text for a query.

    The mode is chosen on every call by ``VOLC_KB_ENABLED``:

    - true: semantic search against the Volcengine knowledge base. Any
      exception is printed and degraded to ``""``.
    - false (default): static context, i.e. the contents of
      ``RAG_CONTEXT_FILE`` when it names a readable file, otherwise
      ``RAG_STATIC_CONTEXT`` (or ``""``).
    """

    async def retrieve(self, query: str) -> str:
        """Return RAG context for *query*, or ``""`` when none is available.

        Retrieval and configuration problems never propagate to the caller.
        Knowledge-base exceptions degrade to ``""``. An unreadable
        ``RAG_CONTEXT_FILE`` falls back to ``RAG_STATIC_CONTEXT``, the same
        as a missing file.
        """
        if env_bool("VOLC_KB_ENABLED", False):
            try:
                return await _retrieve_from_kb(query)
            except Exception as exc:  # best-effort boundary: never block the main LLM call
                # Include the type: httpx timeout errors often stringify to "".
                print(f"⚠️ RAG: 知识库检索异常,已降级: {type(exc).__name__}: {exc}")
                return ""

        # Legacy path: static context.
        context_file = env_str("RAG_CONTEXT_FILE")
        if context_file:
            path = Path(context_file).expanduser()
            # is_file() is already False for a path that does not exist.
            if path.is_file():
                try:
                    return path.read_text(encoding="utf-8")
                except (OSError, UnicodeDecodeError) as exc:
                    # Treat an unreadable/undecodable file like a missing one
                    # instead of raising into the caller.
                    print(f"⚠️ RAG: 读取 RAG_CONTEXT_FILE 失败,已忽略: {exc}")

        return env_str("RAG_STATIC_CONTEXT") or ""
|
||
|
||
|
||
# Module-level RagService instance. Presumably imported and shared by callers;
# verify. RagService holds no state, and configuration is read from the
# environment on each retrieve() call.
rag_service = RagService()
|