"""
Knowledge-base retrieval tool.

Registers RAG as a Function Calling tool so the LLM calls it on demand,
instead of pre-retrieving on every request.
"""

import json
import logging

import httpx

from security.signer import Signer
from tools.registry import tool_registry
from utils.env import env_bool, env_int, env_str

logger = logging.getLogger(__name__)

_KB_API_PATH = "/api/knowledge/collection/search_knowledge"
_KB_SERVICE = "air"
_KB_REGION = "cn-beijing"


def _format_results(result_list: list, attachment_link: bool) -> str:
    """Render retrieved chunks as numbered plain-text blocks.

    Each chunk becomes a ``[i] title`` header (just ``[i]`` when the title is
    empty) followed by its content. When *attachment_link* is true, the
    non-empty ``link`` values from the chunk's ``chunk_attachment`` list are
    appended on an extra line. Blocks are joined with a blank line.
    """
    parts = []
    for i, chunk in enumerate(result_list, start=1):
        title = chunk.get("chunk_title") or ""
        content = chunk.get("content") or ""
        header = f"[{i}] {title}" if title else f"[{i}]"
        block = f"{header}\n{content}"
        if attachment_link:
            attachments = chunk.get("chunk_attachment") or []
            links = [a["link"] for a in attachments if a.get("link")]
            if links:
                block += "\n附件: " + ", ".join(links)
        parts.append(block)
    return "\n\n".join(parts)


def _retrieve_sync(query: str) -> str:
    """Synchronously call the Volcengine knowledge-base search API.

    All configuration is read from environment variables on every call:
    credentials (``CUSTOM_ACCESS_KEY_ID`` / ``CUSTOM_SECRET_KEY``), the
    endpoint, the target collection (``VOLC_KB_RESOURCE_ID``, or
    ``VOLC_KB_NAME`` + ``VOLC_KB_PROJECT``), ``VOLC_KB_TOP_K``, and the
    optional rerank / attachment-link post-processing switches.

    Args:
        query: Search text forwarded verbatim to the API.

    Returns:
        Formatted knowledge snippets on success. Otherwise a short
        user-facing status message: missing configuration, an API error,
        an unparseable response, or no hits.

    Raises:
        Network/transport errors from ``httpx.post`` propagate to the caller;
        the registered tool wrapper below converts them into a message.
    """
    access_key = env_str("CUSTOM_ACCESS_KEY_ID")
    secret_key = env_str("CUSTOM_SECRET_KEY")
    if not access_key or not secret_key:
        return "知识库凭证未配置,无法检索"

    endpoint = env_str(
        "VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com"
    )
    kb_resource_id = env_str("VOLC_KB_RESOURCE_ID")
    kb_name = env_str("VOLC_KB_NAME")
    kb_project = env_str("VOLC_KB_PROJECT", "default")
    top_k = env_int("VOLC_KB_TOP_K", 3)
    rerank = env_bool("VOLC_KB_RERANK", False)
    attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False)

    if not kb_resource_id and not kb_name:
        return "知识库 ID 未配置,无法检索"

    body: dict = {"query": query, "limit": top_k}
    # A resource_id identifies the collection on its own; a name lookup also
    # needs the project it lives in.
    if kb_resource_id:
        body["resource_id"] = kb_resource_id
    else:
        body["name"] = kb_name
        body["project"] = kb_project

    post_processing: dict = {}
    if rerank:
        post_processing["rerank_switch"] = True
    if attachment_link:
        post_processing["get_attachment_link"] = True
    if post_processing:
        body["post_processing"] = post_processing

    headers = {"Content-Type": "application/json"}
    signer = Signer(
        request_data={
            "region": _KB_REGION,
            "method": "POST",
            "path": _KB_API_PATH,
            "headers": headers,
            "body": body,
        },
        service=_KB_SERVICE,
    )
    # NOTE(review): the same `headers` dict is reused for the HTTP request
    # below, which presumes add_authorization writes the auth headers into it
    # in place — confirm against security.signer.
    signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key})

    # NOTE(review): the signature is only valid if Signer hashed the body
    # serialized exactly like this (compact separators, ensure_ascii=False)
    # — verify in security.signer.
    body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
    url = endpoint.rstrip("/") + _KB_API_PATH
    resp = httpx.post(url, headers=headers, content=body_bytes, timeout=15)

    try:
        resp_data = resp.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        resp_data = None
    if not isinstance(resp_data, dict):
        # Gateways/proxies can answer with HTML or an empty body (e.g. a 502).
        # Report the HTTP status rather than an opaque JSON decode error.
        logger.warning("知识库检索返回非 JSON 对象响应: status=%s", resp.status_code)
        return f"知识库检索失败: HTTP {resp.status_code}"

    if resp_data.get("code") != 0:
        msg = resp_data.get("message", "未知错误")
        logger.warning("知识库检索失败: code=%s, message=%s", resp_data.get("code"), msg)
        return f"知识库检索失败: {msg}"

    # `data` (or `result_list`) may be JSON null on success; treat that as no hits.
    result_list = (resp_data.get("data") or {}).get("result_list") or []
    if not result_list:
        return "知识库中未找到相关内容"

    result = _format_results(result_list, attachment_link)
    logger.info("知识库检索到 %d 条结果 (query=%s)", len(result_list), query[:50])
    return result


# ── Tool registration (only when VOLC_KB_ENABLED is true) ──
if env_bool("VOLC_KB_ENABLED", False):

    @tool_registry.register(
        name="search_knowledge",
        description="从知识库检索相关内容,用于回答系统功能介绍、操作说明、规章制度等问题。当用户询问的问题涉及业务知识而非实时数据时使用。",
        parameters={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "检索关键词或问题",
                },
            },
            "required": ["query"],
        },
    )
    def search_knowledge(query: str) -> str:
        # Tool entry point: any exception is logged with its traceback and
        # turned into an error string, so the model gets a result instead of
        # the call failing.
        try:
            return _retrieve_sync(query)
        except Exception as exc:
            logger.exception("知识库检索异常")
            return f"知识库检索出错: {exc}"

    logger.info("✅ RAG 知识库工具已注册 (resource_id=%s)", env_str("VOLC_KB_RESOURCE_ID"))