rtc-voice-chat/backend/tools/builtin/rag_tool.py
2026-04-02 09:40:23 +08:00

129 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识库检索工具 —— 将 RAG 作为 Function Calling 工具注册,
LLM 按需调用,而非每次请求都预检索。
"""
import json
import logging
import httpx
from security.signer import Signer
from tools.registry import tool_registry
from utils.env import env_bool, env_int, env_str
logger = logging.getLogger(__name__)

# Path of the knowledge-base search API; appended to the configured endpoint.
_KB_API_PATH: str = "/api/knowledge/collection/search_knowledge"
# Service name and region passed to the request signer.
_KB_SERVICE: str = "air"
_KB_REGION: str = "cn-beijing"
def _build_search_body(
    query: str,
    *,
    top_k: int,
    kb_resource_id,
    kb_name,
    kb_project,
    rerank: bool,
    attachment_link: bool,
) -> dict:
    """Build the JSON request body for the knowledge-base search API.

    The collection is addressed by ``resource_id`` when one is configured,
    otherwise by ``name`` plus ``project``. ``post_processing`` is included
    only when at least one post-processing option is enabled.
    """
    body: dict = {"query": query, "limit": top_k}
    if kb_resource_id:
        body["resource_id"] = kb_resource_id
    else:
        body["name"] = kb_name
        body["project"] = kb_project
    post_processing: dict = {}
    if rerank:
        post_processing["rerank_switch"] = True
    if attachment_link:
        post_processing["get_attachment_link"] = True
    if post_processing:
        body["post_processing"] = post_processing
    return body


def _format_result_list(result_list: list, attachment_link: bool) -> str:
    """Render search hits as numbered plain-text blocks separated by blank lines.

    Each block is ``[i] <chunk_title>`` (title omitted when empty) followed by
    the chunk ``content``. When ``attachment_link`` is set, non-empty
    attachment links are appended on an extra line.
    """
    parts = []
    for i, chunk in enumerate(result_list, start=1):
        title = chunk.get("chunk_title") or ""
        content = chunk.get("content") or ""
        header = f"[{i}] {title}" if title else f"[{i}]"
        block = f"{header}\n{content}"
        if attachment_link:
            attachments = chunk.get("chunk_attachment") or []
            links = [a["link"] for a in attachments if a.get("link")]
            if links:
                block += "\n附件: " + ", ".join(links)
        parts.append(block)
    return "\n\n".join(parts)


def _retrieve_sync(query: str) -> str:
    """Synchronously search the Volcengine knowledge base and format the hits.

    Configuration is read from environment variables on every call:
    ``CUSTOM_ACCESS_KEY_ID`` / ``CUSTOM_SECRET_KEY`` for credentials and
    ``VOLC_KB_*`` for the endpoint, collection, search options and timeout
    (``VOLC_KB_TIMEOUT``, seconds, default 15).

    Returns:
        The formatted knowledge snippets, or a short (Chinese) status message
        when credentials or the collection ID are missing, the query is blank,
        the response body is not a JSON object, the API reports a non-zero
        ``code``, or no results are found.

    Raises:
        Whatever ``httpx.post`` raises on transport errors or timeouts
        (non-2xx statuses do not raise; they are handled via the body).
    """
    access_key = env_str("CUSTOM_ACCESS_KEY_ID")
    secret_key = env_str("CUSTOM_SECRET_KEY")
    if not access_key or not secret_key:
        return "知识库凭证未配置,无法检索"

    endpoint = env_str("VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com")
    kb_resource_id = env_str("VOLC_KB_RESOURCE_ID")
    kb_name = env_str("VOLC_KB_NAME")
    if not kb_resource_id and not kb_name:
        return "知识库 ID 未配置,无法检索"

    # The LLM may supply an empty/blank query; don't spend an API call on it.
    query = (query or "").strip()
    if not query:
        return "检索关键词为空,无法检索"

    attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False)
    timeout = env_int("VOLC_KB_TIMEOUT", 15)
    body = _build_search_body(
        query,
        top_k=env_int("VOLC_KB_TOP_K", 3),
        kb_resource_id=kb_resource_id,
        kb_name=kb_name,
        kb_project=env_str("VOLC_KB_PROJECT", "default"),
        rerank=env_bool("VOLC_KB_RERANK", False),
        attachment_link=attachment_link,
    )

    headers = {"Content-Type": "application/json"}
    # NOTE(review): the same `headers` dict is sent below, so add_authorization
    # presumably inserts the auth headers into it in place — confirm in security.signer.
    signer = Signer(
        request_data={
            "region": _KB_REGION,
            "method": "POST",
            "path": _KB_API_PATH,
            "headers": headers,
            "body": body,
        },
        service=_KB_SERVICE,
    )
    signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key})
    # NOTE(review): the signature is computed from `body`; this serialization
    # (compact separators, ensure_ascii=False) must match what Signer hashes.
    body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
    url = endpoint.rstrip("/") + _KB_API_PATH
    resp = httpx.post(url, headers=headers, content=body_bytes, timeout=timeout)

    # Gateway/auth failures can return a non-JSON or non-object body; report the
    # HTTP status rather than letting a JSON decode error escape.
    try:
        resp_data = resp.json()
    except ValueError:
        resp_data = None
    if not isinstance(resp_data, dict):
        logger.warning("知识库检索响应无法解析: status=%s", resp.status_code)
        return f"知识库检索失败: HTTP {resp.status_code}"

    if resp_data.get("code") != 0:
        msg = resp_data.get("message", "未知错误")
        logger.warning("知识库检索失败: code=%s, message=%s", resp_data.get("code"), msg)
        return f"知识库检索失败: {msg}"

    # `data` / `result_list` may be JSON null; treat that as "no results".
    result_list = (resp_data.get("data") or {}).get("result_list") or []
    if not result_list:
        return "知识库中未找到相关内容"

    result = _format_result_list(result_list, attachment_link)
    logger.info("知识库检索到 %d 条结果 (query=%s)", len(result_list), query[:50])
    return result
# ── Tool registration ──
# The tool is registered only when VOLC_KB_ENABLED is truthy at import time.
if env_bool("VOLC_KB_ENABLED", False):

    @tool_registry.register(
        name="search_knowledge",
        description="从知识库检索相关内容,用于回答系统功能介绍、操作说明、规章制度等问题。当用户询问的问题涉及业务知识而非实时数据时使用。",
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "检索关键词或问题",
                },
            },
            "required": ["query"],
        },
    )
    def search_knowledge(query: str) -> str:
        """Tool entry point: search the knowledge base for ``query``.

        Delegates to ``_retrieve_sync`` (a blocking HTTP call). Any exception
        is logged with its traceback and converted into an error string, so a
        failed lookup is reported back as the tool result instead of raising.
        """
        try:
            return _retrieve_sync(query)
        except Exception as exc:  # tool boundary: report, don't propagate
            logger.exception("知识库检索异常")
            return f"知识库检索出错: {exc}"

    # Log both identifiers: the KB may be addressed by resource_id OR by name,
    # and logging only resource_id would show None in the name-based setup.
    logger.info(
        "✅ RAG 知识库工具已注册 (resource_id=%s, name=%s)",
        env_str("VOLC_KB_RESOURCE_ID"),
        env_str("VOLC_KB_NAME"),
    )