64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable
|
|
|
|
# Module logger; records are emitted under this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class ToolDef:
    """Immutable record describing one tool.

    ``ToolRegistry.register`` builds these from its decorator arguments plus
    the decorated function. Freezing only forbids reassigning the fields;
    the ``parameters`` dict itself is still a mutable object.
    """

    # Tool name; the registry also uses it as the lookup key.
    name: str
    # Human-readable description exported alongside the name.
    description: str
    # Argument schema, exported verbatim as the tool's ``parameters``.
    parameters: dict[str, Any]
    # Callable invoked with the decoded keyword arguments; expected to return text.
    handler: Callable[..., str]
|
|
|
|
|
|
class ToolRegistry:
    """Name-keyed registry of tools exposed in the OpenAI function-calling format.

    Tools are added with the :meth:`register` decorator, exported as a
    ``tools`` list by :meth:`get_openai_tools`, and dispatched by name with
    :meth:`execute`.
    """

    def __init__(self):
        # Tool name -> definition. Dict insertion order fixes the order in
        # which get_openai_tools() lists the tools.
        self._tools: dict[str, ToolDef] = {}

    def register(self, name: str, description: str, parameters: dict[str, Any]):
        """Decorator: register the decorated function as a tool.

        Args:
            name: Tool name exported by :meth:`get_openai_tools` and used as
                the lookup key in :meth:`execute`.
            description: Description exported alongside the name.
            parameters: Argument schema, emitted verbatim as
                ``function.parameters`` (the OpenAI function-calling format
                expects a JSON Schema object here).

        Returns:
            A decorator that records the function and returns it unchanged,
            so the function stays directly callable.

        Registering a name that is already taken replaces the earlier tool;
        a warning is logged so accidental collisions are visible.
        """

        def decorator(fn: Callable[..., str]) -> Callable[..., str]:
            if name in self._tools:
                logger.warning("工具重复注册,已覆盖原有定义: %s", name)
            self._tools[name] = ToolDef(
                name=name,
                description=description,
                parameters=parameters,
                handler=fn,
            )
            logger.info("工具已注册: %s", name)
            return fn

        return decorator

    def get_openai_tools(self) -> list[dict[str, Any]]:
        """Return all registered tools as an OpenAI ``tools`` list.

        Each entry has the form ``{"type": "function", "function": {"name": ...,
        "description": ..., "parameters": ...}}``. Entries appear in the order
        the names were first registered.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": t.name,
                    "description": t.description,
                    "parameters": t.parameters,
                },
            }
            for t in self._tools.values()
        ]

    def execute(self, name: str, arguments_json: str | None) -> str:
        """Invoke the tool ``name`` with keyword arguments decoded from JSON.

        Args:
            name: Name the tool was registered under.
            arguments_json: JSON object text holding the handler's keyword
                arguments. ``None``, empty or whitespace-only text, and JSON
                ``null`` are all treated as "no arguments".

        Returns:
            The handler's return value, passed through unchanged. Otherwise an
            error message string, returned when:

            - the tool is unknown;
            - the arguments cannot be decoded into a JSON object;
            - the handler raises an ``Exception``.

            These failures are reported in the return value, not raised.
        """
        tool = self._tools.get(name)
        if tool is None:
            return f"错误:未知工具 '{name}'"

        try:
            kwargs = self._parse_arguments(arguments_json)
        except (TypeError, ValueError, RecursionError) as exc:
            # Bad arguments are an input problem, not a handler crash. Log it
            # without a traceback, and keep the same error-message format as
            # handler failures. (json.loads can raise RecursionError on
            # pathologically nested input.)
            logger.warning("工具 %s 参数无效: %s", name, exc)
            return f"工具 {name} 执行出错:{exc}"

        try:
            return tool.handler(**kwargs)
        except Exception as exc:
            # Broad on purpose: any handler failure is turned into a returned
            # message instead of propagating to the caller.
            logger.exception("工具 %s 执行失败", name)
            return f"工具 {name} 执行出错:{exc}"

    @staticmethod
    def _parse_arguments(arguments_json: str | None) -> dict[str, Any]:
        """Decode tool-call argument text into a keyword-argument dict.

        ``None``, empty or whitespace-only text, and JSON ``null`` all yield
        ``{}``.

        Raises:
            TypeError: if ``arguments_json`` is not ``str``, ``bytes`` or
                ``bytearray``.
            ValueError: if the text is not valid JSON (``json.JSONDecodeError``),
                or if it decodes to something other than a JSON object.
        """
        if arguments_json is None:
            return {}
        if not isinstance(arguments_json, (str, bytes, bytearray)):
            raise TypeError(
                f"参数必须是 JSON 字符串,实际为 {type(arguments_json).__name__}"
            )
        if not arguments_json.strip():
            return {}
        decoded = json.loads(arguments_json)
        if decoded is None:
            return {}
        if not isinstance(decoded, dict):
            # ``**`` needs a mapping. Report the real problem instead of the
            # interpreter's "argument after ** must be a mapping" TypeError.
            raise ValueError(
                f"参数必须是 JSON 对象,实际为 {type(decoded).__name__}"
            )
        return decoded
|
|
|
|
|
|
# Module-level default registry instance; presumably imported by the modules
# that decorate their handlers with ``tool_registry.register(...)`` (verify).
tool_registry: ToolRegistry = ToolRegistry()
|