add rag server
This commit is contained in:
parent
9a5cc3e6a0
commit
817e35642c
@ -56,8 +56,19 @@ ARK_TIMEOUT_SECONDS=1800
|
||||
LOCAL_LLM_SYSTEM_PROMPT= "你是一个测试助手。如果别人问你是谁,你就说你是哈哈哈。"
|
||||
LOCAL_LLM_TEMPERATURE=0.3
|
||||
|
||||
# 可选 RAG 占位配置
|
||||
# 当前首版默认未启用主链路 RAG,如需后续接入,可再填写这两个配置
|
||||
# RAG 配置
|
||||
# 方式一:火山引擎知识库(语义检索)
|
||||
# 设置 VOLC_KB_ENABLED=true 后,每次对话前自动检索知识库并注入上下文
|
||||
VOLC_KB_ENABLED=false
|
||||
VOLC_KB_NAME=your_collection_name # 知识库名称(与 VOLC_KB_RESOURCE_ID 二选一)
|
||||
VOLC_KB_RESOURCE_ID= # 知识库唯一 ID(优先级高于 NAME)
|
||||
VOLC_KB_PROJECT=default # 知识库所属项目
|
||||
VOLC_KB_ENDPOINT=https://api-knowledgebase.mlp.cn-beijing.volces.com
|
||||
VOLC_KB_TOP_K=5 # 检索返回条数
|
||||
VOLC_KB_RERANK=false # 是否开启 rerank 重排
|
||||
VOLC_KB_ATTACHMENT_LINK=false # 是否返回图片临时链接(图文混合场景开启,链接有效期 10 分钟)
|
||||
|
||||
# 方式二:静态上下文占位(VOLC_KB_ENABLED=false 时生效)
|
||||
RAG_STATIC_CONTEXT=
|
||||
RAG_CONTEXT_FILE=
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from services.local_llm_service import local_llm_service
|
||||
from services.rag_service import rag_service
|
||||
from services.scene_service import ensure_custom_llm_authorized, get_custom_llm_callback_settings
|
||||
from utils.responses import custom_llm_error_response
|
||||
|
||||
@ -57,9 +58,20 @@ async def chat_callback(request: Request):
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
import time as _time
|
||||
user_query = last_message.get("content", "")
|
||||
_rag_t0 = _time.monotonic()
|
||||
try:
|
||||
rag_context = await rag_service.retrieve(user_query)
|
||||
except Exception as exc:
|
||||
print(f"⚠️ RAG 检索异常,已跳过: {exc}")
|
||||
rag_context = ""
|
||||
print(f"⏱️ RAG 检索耗时: {_time.monotonic() - _rag_t0:.2f}s")
|
||||
|
||||
try:
|
||||
stream_iterator = local_llm_service.chat_stream(
|
||||
history_messages=messages,
|
||||
rag_context=rag_context,
|
||||
request_options={
|
||||
"temperature": payload.get("temperature"),
|
||||
"max_tokens": payload.get("max_tokens"),
|
||||
|
||||
@ -23,8 +23,10 @@ async def debug_chat(request: DebugChatRequest):
|
||||
current_messages.append({"role": "user", "content": request.question})
|
||||
|
||||
start_time = time.time()
|
||||
rag_context = await rag_service.retrieve(request.question)
|
||||
stream_iterator = local_llm_service.chat_stream(
|
||||
history_messages=current_messages,
|
||||
rag_context=rag_context,
|
||||
)
|
||||
|
||||
def generate_text():
|
||||
|
||||
@ -47,6 +47,7 @@ class Signer:
|
||||
"""
|
||||
self.region = request_data.get("region", "cn-north-1")
|
||||
self.method = request_data.get("method", "POST").upper()
|
||||
self.path = request_data.get("path", "/")
|
||||
self.params = request_data.get("params", {})
|
||||
self.headers = request_data.get("headers", {})
|
||||
self.body = request_data.get("body", {})
|
||||
@ -84,7 +85,7 @@ class Signer:
|
||||
canonical_request = "\n".join(
|
||||
[
|
||||
self.method,
|
||||
"/",
|
||||
self.path,
|
||||
canonical_qs,
|
||||
canonical_headers,
|
||||
signed_headers_str,
|
||||
|
||||
@ -1,30 +1,127 @@
|
||||
"""
|
||||
最小可用的 RAG 服务占位实现。
|
||||
RAG 服务。
|
||||
|
||||
当前版本支持两种简单来源:
|
||||
- RAG_STATIC_CONTEXT:直接写在环境变量中的固定知识
|
||||
- RAG_CONTEXT_FILE:读取本地文件全文作为知识上下文
|
||||
支持两种模式:
|
||||
- VOLC_KB_ENABLED=false(默认):走旧的静态上下文逻辑(RAG_STATIC_CONTEXT / RAG_CONTEXT_FILE)
|
||||
- VOLC_KB_ENABLED=true:调用火山引擎知识库 API 进行语义检索
|
||||
|
||||
后续如果要接真正的向量检索,可以直接替换 retrieve 方法实现。
|
||||
出错时自动降级为空字符串,不阻断主链路 LLM 调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from utils.env import env_str
|
||||
import httpx
|
||||
|
||||
from security.signer import Signer
|
||||
from utils.env import env_bool, env_int, env_str
|
||||
|
||||
# Volcengine Knowledge Base search endpoint path, plus the service name and
# region used when signing the request (must match the Signer's expectations).
_KB_API_PATH = "/api/knowledge/collection/search_knowledge"
_KB_SERVICE = "air"
_KB_REGION = "cn-beijing"
|
||||
|
||||
|
||||
def _build_rag_context(result_list: list, with_attachment: bool) -> str:
|
||||
parts = []
|
||||
for i, chunk in enumerate(result_list, start=1):
|
||||
title = chunk.get("chunk_title") or ""
|
||||
content = chunk.get("content") or ""
|
||||
header = f"[{i}] {title}" if title else f"[{i}]"
|
||||
block = f"{header}\n{content}"
|
||||
|
||||
if with_attachment:
|
||||
attachments = chunk.get("chunk_attachment") or []
|
||||
image_links = [a["link"] for a in attachments if a.get("link")]
|
||||
if image_links:
|
||||
block += "\n" + "\n".join(f"" for link in image_links)
|
||||
|
||||
parts.append(block)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
async def _retrieve_from_kb(query: str) -> str:
    """Run a semantic search against the Volcengine Knowledge Base.

    Reads credentials and collection settings from environment variables,
    signs the request with :class:`Signer`, POSTs the query, and formats the
    hits via ``_build_rag_context``.

    Returns "" (after logging) whenever retrieval cannot proceed — missing
    credentials, missing collection config, or a non-zero API ``code`` —
    so the caller can degrade gracefully instead of failing the chat.
    """
    access_key = env_str("CUSTOM_ACCESS_KEY_ID")
    secret_key = env_str("CUSTOM_SECRET_KEY")
    if not access_key or not secret_key:
        # No credentials: skip retrieval rather than error out.
        print("⚠️ RAG: CUSTOM_ACCESS_KEY_ID / CUSTOM_SECRET_KEY 未配置,跳过知识库检索")
        return ""

    endpoint = env_str("VOLC_KB_ENDPOINT", "https://api-knowledgebase.mlp.cn-beijing.volces.com")
    top_k = env_int("VOLC_KB_TOP_K", 5)
    rerank = env_bool("VOLC_KB_RERANK", False)
    attachment_link = env_bool("VOLC_KB_ATTACHMENT_LINK", False)

    body: dict = {"query": query, "limit": top_k}

    kb_name = env_str("VOLC_KB_NAME")
    kb_resource_id = env_str("VOLC_KB_RESOURCE_ID")
    kb_project = env_str("VOLC_KB_PROJECT", "default")

    # resource_id takes precedence over name+project when both are set.
    if kb_resource_id:
        body["resource_id"] = kb_resource_id
    elif kb_name:
        body["name"] = kb_name
        body["project"] = kb_project
    else:
        print("⚠️ RAG: VOLC_KB_NAME 或 VOLC_KB_RESOURCE_ID 未配置,跳过知识库检索")
        return ""

    # Optional post-processing flags: rerank of hits and temporary
    # attachment (image) links.
    post_processing: dict = {}
    if rerank:
        post_processing["rerank_switch"] = True
    if attachment_link:
        post_processing["get_attachment_link"] = True
    if post_processing:
        body["post_processing"] = post_processing

    headers = {"Content-Type": "application/json"}
    signer = Signer(
        request_data={
            "region": _KB_REGION,
            "method": "POST",
            "path": _KB_API_PATH,
            "headers": headers,
            "body": body,
        },
        service=_KB_SERVICE,
    )
    # NOTE(review): presumably mutates `headers` in place to add the
    # Authorization header — confirm against Signer.add_authorization.
    signer.add_authorization({"accessKeyId": access_key, "secretKey": secret_key})

    # The signed payload and the bytes actually sent must use the exact same
    # serialization (compact, no spaces), or the server-side signature check
    # will fail.
    body_bytes = json.dumps(body, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
    url = endpoint.rstrip("/") + _KB_API_PATH
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(url, headers=headers, content=body_bytes)

    resp_data = resp.json()
    if resp_data.get("code") != 0:
        # API-level failure: log and degrade to "no context".
        print(f"⚠️ RAG: 知识库检索失败 code={resp_data.get('code')} message={resp_data.get('message')}")
        return ""

    result_list = resp_data.get("data", {}).get("result_list", [])
    print(f"✅ RAG: 检索到 {len(result_list)} 条知识片段")
    return _build_rag_context(result_list, with_attachment=attachment_link)
|
||||
|
||||
|
||||
class RagService:
    """Facade that retrieves RAG context for a user query.

    Two modes, selected by VOLC_KB_ENABLED:
    - true: live semantic search against the Volcengine Knowledge Base;
    - false (default): static context from RAG_CONTEXT_FILE or
      RAG_STATIC_CONTEXT.
    """

    async def retrieve(self, query: str) -> str:
        # Mode 1: live KB retrieval; any failure degrades to "" so the
        # main LLM call is never blocked by RAG issues.
        if env_bool("VOLC_KB_ENABLED", False):
            try:
                return await _retrieve_from_kb(query)
            except Exception as exc:
                print(f"⚠️ RAG: 知识库检索异常,已降级: {exc}")
                return ""

        # Mode 2 (legacy): static context. A configured but missing/unreadable
        # file silently falls through to the env-var fallback below.
        context_file = env_str("RAG_CONTEXT_FILE")
        if context_file:
            path = Path(context_file).expanduser()
            if path.exists() and path.is_file():
                return path.read_text(encoding="utf-8")

        # `or ""` guards against a None return — TODO confirm env_str's contract.
        return env_str("RAG_STATIC_CONTEXT") or ""


# Module-level singleton imported by the routers.
rag_service = RagService()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user