CheddahBot/cheddahbot/tools/__init__.py

170 lines
5.6 KiB
Python

"""Tool registry with @tool decorator and auto-discovery."""
from __future__ import annotations

import importlib
import inspect
import logging
import pkgutil
import types
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Union, get_args, get_origin, get_type_hints

if TYPE_CHECKING:
    from ..agent import Agent
    from ..config import Config
    from ..db import Database
log = logging.getLogger(__name__)

# Process-wide tool registry (tool name -> ToolDef). Written by the @tool decorator
# and ToolRegistry.register_external; read by every ToolRegistry instance.
_TOOLS: dict[str, ToolDef] = dict()
class ToolDef:
    """Metadata for one tool that can be offered to the model.

    Public attributes, all set in ``__init__``:
      * ``name`` / ``description`` - what the model sees.
      * ``func`` - the Python callable that implements the tool.
      * ``category`` - grouping label used when describing tools.
      * ``parameters`` - ``{"properties": ..., "required": ...}`` derived from
        ``func``'s signature by ``_extract_params``.
    """

    def __init__(self, name: str, description: str, func: Callable, category: str = "general"):
        self.name, self.description, self.func, self.category = name, description, func, category
        self.parameters = _extract_params(func)

    def to_openai_schema(self) -> dict:
        """Return this tool in the OpenAI function-calling format.

        Note: the ``properties`` and ``required`` objects in the result are the same
        objects stored in ``self.parameters`` (not copies).
        """
        params = self.parameters
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": params["properties"],
                "required": params["required"],
            },
        }
        return {"type": "function", "function": function_spec}
def tool(name: str, description: str, category: str = "general"):
    """Create a decorator that registers the decorated function as a tool.

    The function is wrapped in a ``ToolDef``, stored in the global ``_TOOLS``
    registry under ``name`` (replacing any existing entry with that name), and
    tagged with a ``_tool_def`` attribute. The function itself is returned
    unchanged, so it stays directly callable.
    """

    def register(fn: Callable) -> Callable:
        definition = ToolDef(name, description, fn, category)
        _TOOLS[name] = definition
        fn._tool_def = definition
        return fn

    return register
# Python type -> JSON-schema type. Matched by identity (``is``), so ``bool`` is never
# confused with its base class ``int``.
_JSON_TYPES = (
    (str, "string"),
    (int, "integer"),
    (float, "number"),
    (bool, "boolean"),
    (list, "array"),
    (dict, "object"),
)

# The same mapping keyed by how the types are spelled in un-evaluated string annotations.
_JSON_TYPE_NAMES = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "list": "array",
    "List": "array",
    "dict": "object",
    "Dict": "object",
}


def _schema_for_annotation_text(text: str) -> dict[str, Any]:
    """Best-effort JSON-schema fragment for an annotation that is still a string.

    Used when ``get_type_hints`` could not evaluate the annotations. Recognises
    ``X | None``, ``Optional[X]``, ``list[X]`` / ``List[X]`` and the bare names in
    ``_JSON_TYPE_NAMES``. Anything else falls back to ``"string"``, including a
    ``|`` nested inside brackets such as ``"list[int | None]"``.
    """
    members = [part.strip() for part in text.split("|")]
    members = [part for part in members if part and part != "None"]
    if len(members) != 1:
        return {"type": "string"}
    head, bracket, rest = members[0].partition("[")
    head = head.strip().rsplit(".", 1)[-1]  # "typing.List" -> "List"
    rest = rest.rstrip()
    inner = rest[:-1].strip() if bracket and rest.endswith("]") else ""
    if head == "Optional" and inner:
        return _schema_for_annotation_text(inner)
    schema: dict[str, Any] = {"type": _JSON_TYPE_NAMES.get(head, "string")}
    if schema["type"] == "array":
        schema["items"] = _schema_for_annotation_text(inner) if inner else {"type": "string"}
    return schema


def _schema_for_annotation(annotation: Any) -> dict[str, Any]:
    """Return a fresh JSON-schema fragment (``{"type": ...}``, plus ``items`` for arrays)."""
    if isinstance(annotation, str):
        return _schema_for_annotation_text(annotation)
    origin = get_origin(annotation)
    # Optional[X] / X | None describe X; genuine multi-type unions fall back to string.
    if origin is Union or origin is getattr(types, "UnionType", Union):
        members = [arg for arg in get_args(annotation) if arg is not type(None)]
        if len(members) == 1:
            return _schema_for_annotation(members[0])
        return {"type": "string"}
    base = annotation if origin is None else origin  # list[str] -> list
    json_type = next((jt for py_type, jt in _JSON_TYPES if base is py_type), "string")
    schema: dict[str, Any] = {"type": json_type}
    if json_type == "array":
        item_args = get_args(annotation)
        schema["items"] = _schema_for_annotation(item_args[0]) if item_args else {"type": "string"}
    return schema


def _extract_params(func: Callable) -> dict:
    """Build the ``{"properties": {...}, "required": [...]}`` schema for *func*.

    - ``self`` and ``ctx`` are skipped (``ctx`` is injected by ``ToolRegistry.execute``,
      never supplied by the model), as are ``*args`` / ``**kwargs``.
    - Types come from ``typing.get_type_hints``, so string annotations produced by
      ``from __future__ import annotations`` are resolved. If resolution fails, the
      raw annotations are mapped on a best-effort basis. Unknown types become
      ``"string"``.
    - A parameter is required exactly when it has no default value.
    - Descriptions are a generic placeholder; docstrings are not parsed.
    """
    try:
        hints = get_type_hints(func)
    except Exception:
        # e.g. annotations naming types imported only under TYPE_CHECKING.
        hints = {}
    properties: dict[str, Any] = {}
    required: list[str] = []
    for param_name, param in inspect.signature(func).parameters.items():
        if param_name in ("self", "ctx"):
            continue
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        prop = _schema_for_annotation(hints.get(param_name, param.annotation))
        prop["description"] = f"Parameter: {param_name}"
        properties[param_name] = prop
        if param.default is inspect.Parameter.empty:
            required.append(param_name)
    return {"properties": properties, "required": required}
class ToolRegistry:
    """Runtime tool registry with discovery, schema generation and execution.

    Tools are stored in the module-level ``_TOOLS`` dict, so every ToolRegistry
    instance sees the same set of tools.
    """

    def __init__(self, config: Config, db: Database, agent: Agent):
        self.config = config
        self.db = db
        self.agent = agent
        self._discover_tools()

    def _discover_tools(self):
        """Import every non-underscore module in the tools/ package.

        Importing a module runs its ``@tool`` decorators, which populate ``_TOOLS``.
        A module that fails to import is logged with its traceback and skipped;
        the remaining modules still load.
        """
        tools_dir = Path(__file__).parent
        for _, module_name, _ in pkgutil.iter_modules([str(tools_dir)]):
            if module_name.startswith("_"):
                continue
            try:
                importlib.import_module(f".{module_name}", package=__package__)
            except Exception as e:
                log.warning("Failed to load tool module %s: %s", module_name, e, exc_info=True)
            else:
                log.info("Loaded tool module: %s", module_name)

    def get_tools_schema(self, filter_names: list[str] | None = None) -> list[dict]:
        """Return tools in OpenAI function-calling format.

        Args:
            filter_names: if given, only tools whose name is in this collection
                are included; ``None`` means all tools.
        """
        # Set for O(1) membership instead of scanning the list once per tool.
        wanted = None if filter_names is None else set(filter_names)
        return [
            t.to_openai_schema()
            for t in _TOOLS.values()
            if wanted is None or t.name in wanted
        ]

    def get_tools_description(self, filter_names: list[str] | None = None) -> str:
        """Return a human-readable tool list for the system prompt.

        Tools are grouped under a ``### Category`` heading (categories sorted by
        name). Each line shows the tool's required parameters only.

        Args:
            filter_names: optional collection of tool names to include; ``None``
                means all tools.
        """
        wanted = None if filter_names is None else set(filter_names)
        by_cat: dict[str, list[ToolDef]] = {}
        for t in _TOOLS.values():
            if wanted is not None and t.name not in wanted:
                continue
            by_cat.setdefault(t.category, []).append(t)
        lines = []
        for cat, tools in sorted(by_cat.items()):
            lines.append(f"\n### {cat.title()}")
            for t in tools:
                params = ", ".join(t.parameters["required"])
                lines.append(f"- **{t.name}**({params}): {t.description}")
        return "\n".join(lines)

    def execute(self, name: str, args: dict | None) -> str:
        """Execute a tool by name and return the result as a string.

        Never raises for tool failures: an unknown name yields ``"Unknown tool: ..."``
        and any exception from the tool yields ``"Tool error: ..."`` (and is logged).
        A ``None`` result is reported as ``"Done."``.

        Args:
            name: registered tool name.
            args: keyword arguments for the tool; ``None`` is treated as no arguments.
                The caller's dict is not modified.
        """
        tool_def = _TOOLS.get(name)
        if tool_def is None:
            return f"Unknown tool: {name}"
        try:
            # Work on a copy so injecting ``ctx`` does not mutate the caller's dict.
            call_args = dict(args or {})
            # Inject context only if the tool function declares a ``ctx`` parameter.
            if "ctx" in inspect.signature(tool_def.func).parameters:
                call_args["ctx"] = {
                    "config": self.config,
                    "db": self.db,
                    "agent": self.agent,
                    "memory": self.agent._memory,
                }
            result = tool_def.func(**call_args)
            return str(result) if result is not None else "Done."
        except Exception as e:
            log.exception("Tool %s failed: %s", name, e)
            return f"Tool error: {e}"

    def register_external(self, tool_def: ToolDef):
        """Register a dynamically created tool, replacing any tool with the same name."""
        _TOOLS[tool_def.name] = tool_def
        log.info("Registered external tool: %s", tool_def.name)