From 817e35642ca988ef05c3b9dc6a14843a150217d1 Mon Sep 17 00:00:00 2001 From: lengbone <107662693+lengbone@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:39:19 +0800 Subject: [PATCH] add rag server --- backend/.env.example | 15 ++++- backend/routes/chat_callback.py | 12 ++++ backend/routes/debug.py | 2 + backend/security/signer.py | 3 +- backend/services/rag_service.py | 113 +++++++++++++++++++++++++++++--- 5 files changed, 134 insertions(+), 11 deletions(-) diff --git a/backend/.env.example b/backend/.env.example index 4635521..942fb9b 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -56,8 +56,19 @@ ARK_TIMEOUT_SECONDS=1800 LOCAL_LLM_SYSTEM_PROMPT= "你是一个测试助手。如果别人问你是谁,你就说你是哈哈哈。" LOCAL_LLM_TEMPERATURE=0.3 -# 可选 RAG 占位配置 -# 当前首版默认未启用主链路 RAG,如需后续接入,可再填写这两个配置 +# RAG 配置 +# 方式一:火山引擎知识库(语义检索) +# 设置 VOLC_KB_ENABLED=true 后,每次对话前自动检索知识库并注入上下文 +VOLC_KB_ENABLED=false +VOLC_KB_NAME=your_collection_name # 知识库名称(与 VOLC_KB_RESOURCE_ID 二选一) +VOLC_KB_RESOURCE_ID= # 知识库唯一 ID(优先级高于 NAME) +VOLC_KB_PROJECT=default # 知识库所属项目 +VOLC_KB_ENDPOINT=https://api-knowledgebase.mlp.cn-beijing.volces.com +VOLC_KB_TOP_K=5 # 检索返回条数 +VOLC_KB_RERANK=false # 是否开启 rerank 重排 +VOLC_KB_ATTACHMENT_LINK=false # 是否返回图片临时链接(图文混合场景开启,链接有效期 10 分钟) + +# 方式二:静态上下文占位(VOLC_KB_ENABLED=false 时生效) RAG_STATIC_CONTEXT= RAG_CONTEXT_FILE= diff --git a/backend/routes/chat_callback.py b/backend/routes/chat_callback.py index cde9176..dd5367b 100644 --- a/backend/routes/chat_callback.py +++ b/backend/routes/chat_callback.py @@ -8,6 +8,7 @@ from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse from services.local_llm_service import local_llm_service +from services.rag_service import rag_service from services.scene_service import ensure_custom_llm_authorized, get_custom_llm_callback_settings from utils.responses import custom_llm_error_response @@ -57,9 +58,20 @@ async def chat_callback(request: Request): status_code=400, ) + import time as _time + user_query = last_message.get("content", "") + _rag_t0 = _time.monotonic() + try: + rag_context = await rag_service.retrieve(user_query) + except Exception as exc: + print(f"⚠️ RAG 检索异常,已跳过: {exc}") + rag_context = "" + print(f"⏱️ RAG 检索耗时: {_time.monotonic() - _rag_t0:.2f}s") + try: stream_iterator = local_llm_service.chat_stream( history_messages=messages, + rag_context=rag_context, request_options={ "temperature": payload.get("temperature"), "max_tokens": payload.get("max_tokens"), diff --git a/backend/routes/debug.py b/backend/routes/debug.py index 8344cf2..b1b3ede 100644 --- a/backend/routes/debug.py +++ b/backend/routes/debug.py @@ -23,8 +23,10 @@ async def debug_chat(request: DebugChatRequest): current_messages.append({"role": "user", "content": request.question}) start_time = time.time() + rag_context = await rag_service.retrieve(request.question) stream_iterator = local_llm_service.chat_stream( history_messages=current_messages, + rag_context=rag_context, ) def generate_text(): diff --git a/backend/security/signer.py b/backend/security/signer.py index be1d68e..7b492b5 100644 --- a/backend/security/signer.py +++ b/backend/security/signer.py @@ -47,6 +47,7 @@ class Signer: """ self.region = request_data.get("region", "cn-north-1") self.method = request_data.get("method", "POST").upper() + self.path = request_data.get("path", "/") self.params = request_data.get("params", {}) self.headers = request_data.get("headers", {}) self.body = request_data.get("body", {}) @@ -84,7 +85,7 @@ class Signer: canonical_request = "\n".join( [ self.method, - "/", + self.path, canonical_qs, canonical_headers, signed_headers_str, diff --git a/backend/services/rag_service.py b/backend/services/rag_service.py index 65e589d..9ec2f75 100644 --- a/backend/services/rag_service.py +++ b/backend/services/rag_service.py @@ -1,30 +1,127 @@ """ -最小可用的 RAG 服务占位实现。 +RAG 服务。 -当前版本支持两种简单来源: -- RAG_STATIC_CONTEXT:直接写在环境变量中的固定知识 -- RAG_CONTEXT_FILE:读取本地文件全文作为知识上下文 +支持两种模式: +- VOLC_KB_ENABLED=false(默认):走旧的静态上下文逻辑(RAG_STATIC_CONTEXT / RAG_CONTEXT_FILE) +- VOLC_KB_ENABLED=true:调用火山引擎知识库 API 进行语义检索 -后续如果要接真正的向量检索,可以直接替换 retrieve 方法实现。 +出错时自动降级为空字符串,不阻断主链路 LLM 调用。 """ from __future__ import annotations +import json from pathlib import Path -from utils.env import env_str +import httpx + +from security.signer import Signer +from utils.env import env_bool, env_int, env_str + +_KB_API_PATH = "/api/knowledge/collection/search_knowledge" +_KB_SERVICE = "air" +_KB_REGION = "cn-beijing" + + +def _build_rag_context(result_list: list, with_attachment: bool) -> str: + parts = [] + for i, chunk in enumerate(result_list, start=1): + title = chunk.get("chunk_title") or "" + content = chunk.get("content") or "" + header = f"[{i}] {title}" if title else f"[{i}]" + block = f"{header}\n{content}" + + if with_attachment: + attachments = chunk.get("chunk_attachment") or [] + image_links = [a["link"] for a in attachments if a.get("link")] + if image_links: + block += "\n" + "\n".join(f"![图片]({link})" for link in image_links) + + parts.append(block) + return "\n\n".join(parts) + + +async def _retrieve_from_kb(query: str) -> str: + access_key = env_str("CUSTOM_ACCESS_KEY_ID") + secret_key = env_str("CUSTOM_SECRET_KEY") + if not access_key or not secret_key: + print("⚠️ RAG: CUSTOM_ACCESS_KEY_ID / CUSTOM_SECRET_KEY 未配置,跳过知识库检索") + return "" + + endpoint = env_str("VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com") + top_k = env_int("VOLC_KB_TOP_K", 5) + rerank = env_bool("VOLC_KB_RERANK", False) + attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False) + + body: dict = {"query": query, "limit": top_k} + + kb_name = env_str("VOLC_KB_NAME") + kb_resource_id = env_str("VOLC_KB_RESOURCE_ID") + kb_project = env_str("VOLC_KB_PROJECT", "default") + + if kb_resource_id: + body["resource_id"] = kb_resource_id + elif kb_name: + body["name"] = kb_name + body["project"] = kb_project + else: + print("⚠️ RAG: VOLC_KB_NAME 或 VOLC_KB_RESOURCE_ID 未配置,跳过知识库检索") + return "" + + post_processing: dict = {} + if rerank: + post_processing["rerank_switch"] = True + if attachment_link: + post_processing["get_attachment_link"] = True + if post_processing: + body["post_processing"] = post_processing + + headers = {"Content-Type": "application/json"} + signer = Signer( + request_data={ + "region": _KB_REGION, + "method": "POST", + "path": _KB_API_PATH, + "headers": headers, + "body": body, + }, + service=_KB_SERVICE, + ) + signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key}) + + # 签名与发送必须用完全相同的序列化(紧凑无空格) + body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8") + url = endpoint.rstrip("/") + _KB_API_PATH + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post(url, headers=headers, content=body_bytes) + + resp_data = resp.json() + if resp_data.get("code") != 0: + print(f"⚠️ RAG: 知识库检索失败 code={resp_data.get('code')} message={resp_data.get('message')}") + return "" + + result_list = resp_data.get("data", {}).get("result_list", []) + print(f"✅ RAG: 检索到 {len(result_list)} 条知识片段") + return _build_rag_context(result_list, with_attachment=attachment_link) class RagService: async def retrieve(self, query: str) -> str: - _ = query + if env_bool("VOLC_KB_ENABLED", False): + try: + return await _retrieve_from_kb(query) + except Exception as exc: + print(f"⚠️ RAG: 知识库检索异常,已降级: {exc}") + return "" + + # 旧逻辑:静态上下文 context_file = env_str("RAG_CONTEXT_FILE") if context_file: path = Path(context_file).expanduser() if path.exists() and path.is_file(): return path.read_text(encoding="utf-8") - return env_str("RAG_STATIC_CONTEXT") + return env_str("RAG_STATIC_CONTEXT") or "" rag_service = RagService()