""" RAG 服务。 支持两种模式: - VOLC_KB_ENABLED=false(默认):走旧的静态上下文逻辑(RAG_STATIC_CONTEXT / RAG_CONTEXT_FILE) - VOLC_KB_ENABLED=true:调用火山引擎知识库 API 进行语义检索 出错时自动降级为空字符串,不阻断主链路 LLM 调用。 """ from __future__ import annotations import json from pathlib import Path import httpx from security.signer import Signer from utils.env import env_bool, env_int, env_str _KB_API_PATH = "/api/knowledge/collection/search_knowledge" _KB_SERVICE = "air" _KB_REGION = "cn-beijing" def _build_rag_context(result_list: list, with_attachment: bool) -> str: parts = [] for i, chunk in enumerate(result_list, start=1): title = chunk.get("chunk_title") or "" content = chunk.get("content") or "" header = f"[{i}] {title}" if title else f"[{i}]" block = f"{header}\n{content}" if with_attachment: attachments = chunk.get("chunk_attachment") or [] image_links = [a["link"] for a in attachments if a.get("link")] if image_links: block += "\n" + "\n".join(f"![图片]({link})" for link in image_links) parts.append(block) return "\n\n".join(parts) async def _retrieve_from_kb(query: str) -> str: access_key = env_str("CUSTOM_ACCESS_KEY_ID") secret_key = env_str("CUSTOM_SECRET_KEY") if not access_key or not secret_key: print("⚠️ RAG: CUSTOM_ACCESS_KEY_ID / CUSTOM_SECRET_KEY 未配置,跳过知识库检索") return "" endpoint = env_str("VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com") top_k = env_int("VOLC_KB_TOP_K", 5) rerank = env_bool("VOLC_KB_RERANK", False) attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False) body: dict = {"query": query, "limit": top_k} kb_name = env_str("VOLC_KB_NAME") kb_resource_id = env_str("VOLC_KB_RESOURCE_ID") kb_project = env_str("VOLC_KB_PROJECT", "default") if kb_resource_id: body["resource_id"] = kb_resource_id elif kb_name: body["name"] = kb_name body["project"] = kb_project else: print("⚠️ RAG: VOLC_KB_NAME 或 VOLC_KB_RESOURCE_ID 未配置,跳过知识库检索") return "" post_processing: dict = {} if rerank: post_processing["rerank_switch"] = True if attachment_link: post_processing["get_attachment_link"] = True if post_processing: body["post_processing"] = post_processing headers = {"Content-Type": "application/json"} signer = Signer( request_data={ "region": _KB_REGION, "method": "POST", "path": _KB_API_PATH, "headers": headers, "body": body, }, service=_KB_SERVICE, ) signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key}) # 签名与发送必须用完全相同的序列化(紧凑无空格) body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8") url = endpoint.rstrip("/") + _KB_API_PATH async with httpx.AsyncClient(timeout=30) as client: resp = await client.post(url, headers=headers, content=body_bytes) resp_data = resp.json() if resp_data.get("code") != 0: print(f"⚠️ RAG: 知识库检索失败 code={resp_data.get('code')} message={resp_data.get('message')}") return "" result_list = resp_data.get("data", {}).get("result_list", []) print(f"✅ RAG: 检索到 {len(result_list)} 条知识片段") return _build_rag_context(result_list, with_attachment=attachment_link) class RagService: async def retrieve(self, query: str) -> str: if env_bool("VOLC_KB_ENABLED", False): try: return await _retrieve_from_kb(query) except Exception as exc: print(f"⚠️ RAG: 知识库检索异常,已降级: {exc}") return "" # 旧逻辑:静态上下文 context_file = env_str("RAG_CONTEXT_FILE") if context_file: path = Path(context_file).expanduser() if path.exists() and path.is_file(): return path.read_text(encoding="utf-8") return env_str("RAG_STATIC_CONTEXT") or "" rag_service = RagService()