165 lines
5.3 KiB
Python
165 lines
5.3 KiB
Python
"""Tool registry with @tool decorator and auto-discovery."""
|
|
|
|
from __future__ import annotations

import copy
import importlib
import inspect
import json
import logging
import pkgutil
import types
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Union, get_args, get_origin, get_type_hints

if TYPE_CHECKING:
    from ..agent import Agent
    from ..config import Config
    from ..db import Database
|
|
|
|
# Module logger, named after this module's import path.
log = logging.getLogger(name=__name__)

# Global tool registry: tool name -> ToolDef. Filled by the @tool decorator
# when tool modules are imported, and by ToolRegistry.register_external.
_TOOLS: dict[str, ToolDef] = dict()
|
|
|
|
|
|
class ToolDef:
    """Metadata for a single registered tool.

    Attributes:
        name: Registry key; also the function name in the generated schema.
        description: Summary placed in the schema and the tool listing.
        func: Callable invoked when the tool is executed.
        category: Grouping label used when listing tools.
        parameters: ``{"properties": {...}, "required": [...]}`` derived from
            ``func``'s signature when the ToolDef is constructed.
    """

    def __init__(self, name: str, description: str, func: Callable, category: str = "general"):
        self.name = name
        self.description = description
        self.func = func
        self.category = category
        # Computed once so schema generation never re-inspects ``func``.
        self.parameters = _extract_params(func)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, category={self.category!r})"

    def to_openai_schema(self) -> dict:
        """Return this tool in OpenAI function-calling format.

        The parameter properties and required list are copied, so mutating the
        returned dict cannot alter this tool's stored ``parameters``.
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": copy.deepcopy(self.parameters["properties"]),
                    "required": list(self.parameters["required"]),
                },
            },
        }
|
|
|
|
|
|
def _callable_origin(func: Callable) -> tuple[Any, Any]:
    """Return ``(module, qualname)`` identifying where ``func`` was defined."""
    return getattr(func, "__module__", None), getattr(func, "__qualname__", None)


def tool(name: str, description: str, category: str = "general"):
    """Decorator that registers the decorated function as a tool.

    Builds a ToolDef, stores it in the module-level ``_TOOLS`` registry under
    ``name``, attaches it to the function as ``func._tool_def`` and returns the
    function itself, so it remains directly callable.

    Args:
        name: Registry key and function name advertised in the tool schema.
        description: Summary shown in the schema and the tool listing.
        category: Grouping label used when listing tools.

    Returns:
        The registering decorator.
    """

    def decorator(func: Callable) -> Callable:
        tool_def = ToolDef(name, description, func, category)
        previous = _TOOLS.get(name)
        # Re-registering the same definition (e.g. a module reload) is benign;
        # a *different* function claiming an existing name is almost always an
        # accidental collision. The newer registration still wins.
        if previous is not None and _callable_origin(previous.func) != _callable_origin(func):
            log.warning(
                "Tool %r from %s.%s overrides existing registration from %s.%s",
                name,
                *_callable_origin(func),
                *_callable_origin(previous.func),
            )
        _TOOLS[name] = tool_def
        func._tool_def = tool_def
        return func

    return decorator
|
|
|
|
|
|
def _extract_params(func: Callable) -> dict:
|
|
"""Extract parameter schema from function signature and type hints."""
|
|
sig = inspect.signature(func)
|
|
properties = {}
|
|
required = []
|
|
|
|
for param_name, param in sig.parameters.items():
|
|
if param_name in ("self", "ctx"):
|
|
continue
|
|
|
|
prop: dict[str, Any] = {}
|
|
annotation = param.annotation
|
|
|
|
if annotation == str or annotation == inspect.Parameter.empty:
|
|
prop["type"] = "string"
|
|
elif annotation == int:
|
|
prop["type"] = "integer"
|
|
elif annotation == float:
|
|
prop["type"] = "number"
|
|
elif annotation == bool:
|
|
prop["type"] = "boolean"
|
|
elif annotation == list:
|
|
prop["type"] = "array"
|
|
prop["items"] = {"type": "string"}
|
|
else:
|
|
prop["type"] = "string"
|
|
|
|
# Check for description in docstring (simple parsing)
|
|
prop["description"] = f"Parameter: {param_name}"
|
|
|
|
properties[param_name] = prop
|
|
|
|
if param.default is inspect.Parameter.empty:
|
|
required.append(param_name)
|
|
|
|
return {"properties": properties, "required": required}
|
|
|
|
|
|
class ToolRegistry:
    """Runtime view over the global tool registry.

    Holds the config/db/agent handles passed to tools that declare a ``ctx``
    parameter, and imports every tool module in this package on construction.
    """

    def __init__(self, config: "Config", db: "Database", agent: "Agent"):
        self.config = config
        self.db = db
        self.agent = agent
        self._discover_tools()

    def _discover_tools(self):
        """Import every module in this package whose name has no leading underscore.

        Importing a module runs any ``@tool`` decorators it contains, which is
        how tools get registered. A module that fails to import is logged
        (with traceback) and skipped, so the remaining modules still load.
        """
        tools_dir = Path(__file__).parent
        for _, module_name, _ in pkgutil.iter_modules([str(tools_dir)]):
            if module_name.startswith("_"):
                continue
            try:
                importlib.import_module(f".{module_name}", package=__package__)
            except Exception as e:
                log.warning(
                    "Failed to load tool module %s: %s", module_name, e, exc_info=True
                )
            else:
                log.info("Loaded tool module: %s", module_name)

    def get_tools_schema(self) -> list[dict]:
        """Return every registered tool in OpenAI function-calling format."""
        return [t.to_openai_schema() for t in _TOOLS.values()]

    def get_tools_description(self) -> str:
        """Return a Markdown listing of tools for the system prompt.

        Tools are grouped under one ``### <Category>`` heading per category
        (categories sorted by name); each entry shows the tool's required
        parameter names and its description.
        """
        by_cat: dict[str, list[ToolDef]] = {}
        for t in _TOOLS.values():
            by_cat.setdefault(t.category, []).append(t)

        lines = []
        for cat, tools in sorted(by_cat.items()):
            lines.append(f"\n### {cat.title()}")
            for t in tools:
                params = ", ".join(t.parameters["required"])
                lines.append(f"- **{t.name}**({params}): {t.description}")
        return "\n".join(lines)

    def execute(self, name: str, args: dict) -> str:
        """Run the tool registered as ``name`` and return its result as text.

        Args:
            name: Registered tool name.
            args: Keyword arguments for the tool; ``None`` means no arguments.
                The caller's dict is never modified.

        Returns:
            ``str(result)``, or ``"Done."`` when the tool returns ``None``;
            ``"Unknown tool: <name>"`` for an unregistered name; or
            ``"Tool error: <exc>"`` if the call raises (the exception is logged
            with its traceback instead of propagating).
        """
        tool_def = _TOOLS.get(name)
        if tool_def is None:
            return f"Unknown tool: {name}"

        try:
            # Work on a copy: the injected ctx holds live config/db/agent
            # handles that must not leak back into the caller's dict.
            call_args = dict(args or {})
            # Tools opt into runtime context by declaring a ``ctx`` parameter.
            if "ctx" in inspect.signature(tool_def.func).parameters:
                call_args["ctx"] = {
                    "config": self.config,
                    "db": self.db,
                    "agent": self.agent,
                    "memory": self.agent._memory,
                }
            result = tool_def.func(**call_args)
            return str(result) if result is not None else "Done."
        except Exception as e:
            log.error("Tool %s failed: %s", name, e, exc_info=True)
            return f"Tool error: {e}"

    def register_external(self, tool_def: ToolDef):
        """Add a ToolDef created outside the @tool decorator.

        An existing tool with the same name is replaced.
        """
        _TOOLS[tool_def.name] = tool_def
        log.info("Registered external tool: %s", tool_def.name)