129 lines
4.3 KiB
Python
129 lines
4.3 KiB
Python
"""
|
||
知识库检索工具 —— 将 RAG 作为 Function Calling 工具注册,
由 LLM 按需调用,而非每次请求都预先检索。
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
|
||
import httpx
|
||
|
||
from security.signer import Signer
|
||
from tools.registry import tool_registry
|
||
from utils.env import env_bool, env_int, env_str
|
||
|
||
# Module-level logger shared by the retrieval helper and the registered tool.
logger = logging.getLogger(__name__)

# Request-signing parameters for the Volcengine knowledge-base API: the
# signer receives the service name, region and path, and the same path is
# appended to the configured endpoint to form the request URL.
_KB_SERVICE: str = "air"
_KB_REGION: str = "cn-beijing"
_KB_API_PATH: str = "/api/knowledge/collection/search_knowledge"
|
||
|
||
|
||
def _build_request_body(
    query: str,
    top_k: int,
    kb_resource_id,
    kb_name,
    kb_project,
    rerank: bool,
    attachment_link: bool,
) -> dict:
    """Build the JSON body for the search_knowledge API.

    The collection is addressed by ``resource_id`` when it is set, otherwise
    by ``name`` + ``project``. ``post_processing`` is only included when at
    least one post-processing switch is enabled.
    """
    body: dict = {"query": query, "limit": top_k}
    if kb_resource_id:
        body["resource_id"] = kb_resource_id
    else:
        body["name"] = kb_name
        body["project"] = kb_project

    post_processing: dict = {}
    if rerank:
        post_processing["rerank_switch"] = True
    if attachment_link:
        post_processing["get_attachment_link"] = True
    if post_processing:
        body["post_processing"] = post_processing
    return body


def _format_result_list(result_list: list, attachment_link: bool) -> str:
    """Render retrieved chunks as numbered text blocks separated by blank lines.

    Each block is ``[i] <chunk_title>`` (title omitted when empty) followed by
    the chunk content; when ``attachment_link`` is true, any attachment links
    are appended on a trailing ``附件:`` line.
    """
    parts = []
    for i, chunk in enumerate(result_list, start=1):
        title = chunk.get("chunk_title") or ""
        content = chunk.get("content") or ""
        header = f"[{i}] {title}" if title else f"[{i}]"
        block = f"{header}\n{content}"

        if attachment_link:
            attachments = chunk.get("chunk_attachment") or []
            image_links = [a["link"] for a in attachments if a.get("link")]
            if image_links:
                block += "\n附件: " + ", ".join(image_links)

        parts.append(block)
    return "\n\n".join(parts)


def _retrieve_sync(query: str) -> str:
    """Synchronously call the Volcengine knowledge-base search API.

    Configuration (credentials, endpoint, collection id/name, top-k, rerank,
    attachment links) is read from environment variables on every call.

    Args:
        query: Search text, sent to the API as the ``query`` field.

    Returns:
        The retrieved chunks formatted as text, or a short status message
        when credentials / collection id are missing, the API reports an
        error or returns a malformed response, or nothing matched.

    Raises:
        Any exception raised by the signer or by ``httpx.post`` (e.g. network
        errors or timeouts) propagates; the registered tool wrapper in this
        module catches it.
    """
    access_key = env_str("CUSTOM_ACCESS_KEY_ID")
    secret_key = env_str("CUSTOM_SECRET_KEY")
    if not access_key or not secret_key:
        return "知识库凭证未配置,无法检索"

    endpoint = env_str("VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com")
    kb_resource_id = env_str("VOLC_KB_RESOURCE_ID")
    kb_name = env_str("VOLC_KB_NAME")
    kb_project = env_str("VOLC_KB_PROJECT", "default")
    top_k = env_int("VOLC_KB_TOP_K", 3)
    rerank = env_bool("VOLC_KB_RERANK", False)
    attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False)

    if not kb_resource_id and not kb_name:
        return "知识库 ID 未配置,无法检索"

    body = _build_request_body(
        query, top_k, kb_resource_id, kb_name, kb_project, rerank, attachment_link
    )

    headers = {"Content-Type": "application/json"}
    signer = Signer(
        request_data={
            "region": _KB_REGION,
            "method": "POST",
            "path": _KB_API_PATH,
            "headers": headers,
            "body": body,
        },
        service=_KB_SERVICE,
    )
    # NOTE(review): nothing else adds auth headers before the POST, so
    # add_authorization presumably writes them into this same ``headers``
    # dict — keep sending the same object, do not copy it.
    signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key})

    # NOTE(review): the bytes sent must match what the signer hashed; compact
    # separators + ensure_ascii=False are assumed to mirror the signer's own
    # serialization — TODO confirm against security.signer.
    body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
    url = endpoint.rstrip("/") + _KB_API_PATH

    resp = httpx.post(url, headers=headers, content=body_bytes, timeout=15)

    # A gateway error page (e.g. an HTML 5xx) is not JSON; report the HTTP
    # status instead of surfacing an opaque JSONDecodeError.
    try:
        resp_data = resp.json()
    except ValueError:
        logger.warning("知识库检索失败: HTTP %s, 响应不是合法 JSON", resp.status_code)
        return f"知识库检索失败: HTTP {resp.status_code}"
    if not isinstance(resp_data, dict):
        logger.warning("知识库检索失败: 响应格式异常 (HTTP %s)", resp.status_code)
        return f"知识库检索失败: HTTP {resp.status_code}"

    if resp_data.get("code") != 0:
        msg = resp_data.get("message", "未知错误")
        logger.warning("知识库检索失败: code=%s, message=%s", resp_data.get("code"), msg)
        return f"知识库检索失败: {msg}"

    # "data" / "result_list" may be present but null; ``or`` avoids calling
    # .get on None.
    data = resp_data.get("data") or {}
    result_list = data.get("result_list") or []
    if not result_list:
        return "知识库中未找到相关内容"

    result = _format_result_list(result_list, attachment_link)
    logger.info("知识库检索到 %d 条结果 (query=%s)", len(result_list), query[:50])
    return result
|
||
|
||
|
||
# ── 注册工具 ──
|
||
|
||
# The tool is registered at import time, only when VOLC_KB_ENABLED is set
# (disabled by default).
if env_bool("VOLC_KB_ENABLED", False):

    @tool_registry.register(
        name="search_knowledge",
        description="从知识库检索相关内容,用于回答系统功能介绍、操作说明、规章制度等问题。当用户询问的问题涉及业务知识而非实时数据时使用。",
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "检索关键词或问题",
                },
            },
            "required": ["query"],
        },
    )
    def search_knowledge(query: str) -> str:
        """Tool entry point: search the knowledge base for ``query``.

        Any ``Exception`` raised during retrieval is logged with its traceback
        and returned as an error string, so the caller always receives a
        string rather than an exception.
        """
        try:
            return _retrieve_sync(query)
        except Exception as exc:  # tool boundary: report failure as output
            logger.exception("知识库检索异常")
            return f"知识库检索出错: {exc}"

    # Logged inside the ``if`` so it only appears when the tool was actually
    # registered; the collection may be addressed by resource id or by name,
    # so both are reported.
    logger.info(
        "✅ RAG 知识库工具已注册 (resource_id=%s, name=%s)",
        env_str("VOLC_KB_RESOURCE_ID"),
        env_str("VOLC_KB_NAME"),
    )
|