splinter-keep/tools/engine_lib/tools_handler.py
2026-07-01 23:01:01 +02:00

263 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import random
import re
from .paths import CHAR_PATH, WORLD_PATH
from .state import read_file, validate_update_size, update_journal, append_llm_log
# Catalogue of the tools known to the engine, keyed by tool name.
# Every entry has a one-line "description" and an "args" mapping from
# argument name to an example or placeholder value.
# NOTE(review): presumably rendered into the LLM prompt - confirm with the caller.
TOOL_REGISTRY: dict[str, dict] = {
    "roll": {
        "description": "Roll dice.",
        "args": {"dice": "1d6", "modifier": "+1"},
    },
    "player_roll": {
        "description": "Ask player to roll.",
        "args": {"dice": "1d6", "reason": "why"},
    },
    "modify_traits": {
        "description": "Change STR/DEX/WIL.",
        "args": {"str": "optional", "dex": "optional", "wil": "optional"},
    },
    "modify_vitals": {
        "description": "Change HP, cash, weapon, armour.",
        "args": {
            "current_hp": "optional",
            "max_hp": "optional",
            "cash": "optional",
            "weapon": "optional",
            "armour": "optional",
        },
    },
    "add_to_inventory": {
        "description": "Add item to gear.",
        "args": {"item": "item name and stats"},
    },
    "remove_from_inventory": {
        "description": "Remove item from gear.",
        "args": {"item": "exact item text"},
    },
    "replace_gear": {
        "description": "Replace gear by exact match.",
        "args": {"before": "exact text", "after": "new text"},
    },
    "add_note": {
        "description": "Add note to sheet.",
        "args": {"note": "note content"},
    },
    "replace_note": {
        "description": "Replace note by exact match.",
        "args": {"before": "exact text", "after": "new text"},
    },
    "world_update": {
        "description": "Replace world state.",
        "args": {"content": "full world markdown"},
    },
    "journal_update": {
        "description": "Update TODO/DONE.",
        "args": {"add": "[...]", "done": "[...]"},
    },
    "finalize_turn": {
        "description": "End turn.",
        "args": {
            "ambience": "soundscape name",
            "log_entry": "one-line summary of what happened",
        },
    },
}
def patch_character(pattern: str, repl: str, count: int = 1, flags: int = 0) -> str:
    """Run a regex substitution over the character sheet stored at CHAR_PATH.

    The file is rewritten only when the pattern matched at least once.

    Args:
        pattern: Regular expression to search for.
        repl: Replacement, handed to ``re.subn`` unchanged.
        count: Maximum number of substitutions (``re.subn`` semantics).
        flags: ``re`` flags applied to the search.

    Returns:
        An empty string on success, or an error message naming the pattern
        when nothing matched (the file is left untouched in that case).
    """
    sheet = CHAR_PATH.read_text()
    patched, hits = re.subn(pattern, repl, sheet, count=count, flags=flags)
    if hits:
        CHAR_PATH.write_text(patched)
        return ""
    return f"**Error:** pattern not found:\n{pattern}"
def tool_roll(args: dict) -> str:
    """Roll dice given in ``NdS`` notation and describe the outcome.

    Args:
        args: Tool arguments (``None`` is treated as empty).
            ``dice`` is a string such as ``"2d6"``; the count may be omitted
            (``"d6"`` rolls one die). Defaults to ``"1d6"``.
            ``modifier`` is an optional integer or integer-like string
            (``"+1"``, ``"-2"``) added to the total.

    Returns:
        ``"Roll: <dice> [<modifier>] → [<rolls>] = <total>"`` on success, or
        an ``"Invalid dice: ..."`` message when the notation cannot be
        parsed, names fewer than one die or side, or asks for too many dice.
    """
    # Upper bound on dice per call so a malformed request cannot allocate an
    # arbitrarily large list of rolls.
    max_dice = 1000

    opts = args or {}
    dice_str = opts.get("dice", "1d6")
    modifier_str = opts.get("modifier", "0")
    invalid = f"Invalid dice: {dice_str}. Use format like '2d6'."

    try:
        # str() guards against non-string values such as a bare integer.
        count_part, sides_part = str(dice_str).lower().split("d")
        count = int(count_part) if count_part else 1
        sides = int(sides_part)
    except (ValueError, TypeError):
        return invalid

    # random.randint(1, sides) raises for sides < 1, and a non-positive count
    # would silently produce an empty roll.
    if count < 1 or sides < 1:
        return invalid
    if count > max_dice:
        return f"Invalid dice: {dice_str}. At most {max_dice} dice per roll."

    mod = 0
    if modifier_str:
        try:
            mod = int(modifier_str)
        except (ValueError, TypeError):
            # Deliberate best effort: an unreadable modifier rolls unmodified.
            mod = 0

    rolls = [random.randint(1, sides) for _ in range(count)]
    total = sum(rolls) + mod
    mod_str = f" {'+' if mod >= 0 else ''}{mod}" if mod != 0 else ""
    return f"Roll: {dice_str}{mod_str} → [{', '.join(str(r) for r in rolls)}] = {total}"
def tool_modify_traits(args: dict) -> str:
    """Overwrite the STR / DEX / WIL values on the character sheet.

    For each of ``str``, ``dex`` and ``wil`` present (and not ``None``) in
    ``args``, the digits following the matching ``- **STR:**``-style bullet
    are replaced through ``patch_character``.

    Args:
        args: Tool arguments; ``None`` is treated as empty.

    Returns:
        ``"Traits updated."`` when no patch reported an error (including when
        no trait was supplied), otherwise the error messages joined by
        ``"; "``.
    """
    opts = args or {}
    errors = []
    for stat in ("str", "dex", "wil"):
        val = opts.get(stat)
        if val is None:
            continue
        err = patch_character(
            rf"^(- \*\*{stat.upper()}:\*\*\s*)\d+",
            # A callable replacement inserts the value literally; a template
            # string would interpret backslashes in the value as escapes.
            lambda m, v=val: f"{m.group(1)}{v}",
            count=1,
            flags=re.MULTILINE,
        )
        if err:
            errors.append(err)
    return "; ".join(errors) if errors else "Traits updated."
def tool_modify_vitals(args: dict) -> str:
    """Overwrite health, cash, weapon and armour lines on the character sheet.

    Each supplied argument replaces the remainder of the corresponding
    ``- **<Label>:**`` bullet line through ``patch_character``.

    Args:
        args: Tool arguments; ``None`` is treated as empty. Recognised keys
            are ``current_hp``, ``max_hp``, ``cash``, ``weapon`` and
            ``armour``; keys that are missing or ``None`` are skipped.

    Returns:
        ``"Vitals updated."`` when no patch reported an error (including when
        nothing was supplied), otherwise the error messages joined by
        ``"; "``.
    """
    opts = args or {}
    fields = (
        ("current_hp", "Current Health"),
        ("max_hp", "Max Health"),
        ("cash", "Cash"),
        ("weapon", "Weapon"),
        ("armour", "Armour"),
    )
    errors = []
    for field, label in fields:
        val = opts.get(field)
        if val is None:
            continue
        err = patch_character(
            # Only spaces/tabs may follow the label: `\s*` would also consume
            # the newline of an empty field and overwrite the following line.
            rf"^(- \*\*{label}:\*\*)([ \t]*).*",
            # Callable replacement: the value is inserted literally instead of
            # being parsed as a regex replacement template. Existing spacing
            # is kept; an empty field gets a single separating space.
            lambda m, v=val: f"{m.group(1)}{m.group(2) or ' '}{v}",
            count=1,
            flags=re.MULTILINE,
        )
        if err:
            errors.append(err)
    return "; ".join(errors) if errors else "Vitals updated."
def tool_add_to_inventory(args: dict) -> str:
    """Add a bullet for ``item`` at the top of the sheet's ``## Gear`` section.

    The item counts as already present only when the Gear section holds a
    bullet whose text is exactly ``item``. Text elsewhere on the sheet (notes,
    vitals labels) or a longer item that merely contains ``item`` does not
    block the addition. When the sheet has no ``## Gear`` heading, a new
    section is appended at the end of the file.

    Args:
        args: Tool arguments; ``item`` is the bullet text to add.

    Returns:
        A confirmation, an "already in inventory" notice, or an error message
        when ``item`` is missing or empty.
    """
    item = (args or {}).get("item", "")
    if not item:
        return "**Error:** `item` is required."

    text = CHAR_PATH.read_text()
    bullet = f"- {item}"
    heading = re.search(r"^## Gear\n", text, re.MULTILINE)
    if heading is None:
        text += f"\n## Gear\n{bullet}\n"
    else:
        body_start = heading.end()
        # The section runs until the next heading of the same or higher level.
        next_heading = re.search(r"^#{1,2} ", text[body_start:], re.MULTILINE)
        body_end = body_start + next_heading.start() if next_heading else len(text)
        gear_lines = text[body_start:body_end].splitlines()
        if any(line.rstrip() == bullet for line in gear_lines):
            return f"Item already in inventory: {item}"
        text = text[:body_start] + f"{bullet}\n" + text[body_start:]

    CHAR_PATH.write_text(text)
    return f"Added to inventory: {item}"
def tool_remove_from_inventory(args: dict) -> str:
    """Delete the first bullet line whose text is exactly ``item``.

    The search goes through ``patch_character`` and therefore covers the
    whole character sheet, not only the Gear section.

    Args:
        args: Tool arguments; ``item`` is the exact bullet text to remove.

    Returns:
        A confirmation, or an error message when ``item`` is missing or no
        matching bullet exists.
    """
    item = (args or {}).get("item", "")
    if not item:
        return "**Error:** `item` is required."
    # Anchor at end of line: without it, removing "Rope" would strip the
    # prefix from "- Rope ladder" and leave a dangling " ladder" line.
    pattern = rf"^- {re.escape(item)}[ \t]*$\n?"
    err = patch_character(pattern, "", count=1, flags=re.MULTILINE)
    if err:
        return f"**Error:** item not found: {item}"
    return f"Removed from inventory: {item}"
def tool_replace_gear(args: dict) -> str:
    """Replace the text of a gear bullet on the character sheet.

    The first line in the sheet that starts with ``- <before>`` has that
    prefix rewritten to ``- <after>`` via ``patch_character``; anything that
    follows ``before`` on the same line is kept.

    Args:
        args: Tool arguments; both ``before`` (existing text) and ``after``
            (new text) are required and must be non-empty.

    Returns:
        A confirmation showing the old and new text, or an error message when
        an argument is missing or no bullet starts with ``before``.
    """
    opts = args or {}
    before = opts.get("before", "")
    after = opts.get("after", "")
    if not before or not after:
        return "**Error:** `before` and `after` are required."
    err = patch_character(
        rf"^- {re.escape(before)}",
        # Callable replacement so backslashes or "\g<...>" inside `after` are
        # written literally rather than parsed as a replacement template.
        lambda _m: f"- {after}",
        count=1,
        flags=re.MULTILINE,
    )
    if err:
        return f"**Error:** gear not found: {before}"
    # Separate old and new text so the message stays readable.
    return f"Gear replaced: {before} → {after}"
def tool_add_note(args: dict) -> str:
    """Insert a note bullet directly under ``## Notes & Scribbles``.

    When the sheet has no such heading, the heading and the note are appended
    to the end of the file instead.

    Args:
        args: Tool arguments; ``note`` is the note text.

    Returns:
        A confirmation, or an error message when ``note`` is missing or empty.
    """
    opts = args or {}
    note = opts.get("note", "")
    if not note:
        return "**Error:** `note` is required."

    sheet = CHAR_PATH.read_text()
    heading = re.search(r"^## Notes & Scribbles\n", sheet, re.MULTILINE)
    if heading is None:
        sheet += f"\n## Notes & Scribbles\n- {note}\n"
    else:
        cut = heading.end()
        sheet = "".join((sheet[:cut], f"- {note}\n", sheet[cut:]))
    CHAR_PATH.write_text(sheet)
    return f"Note added: {note}"
def tool_replace_note(args: dict) -> str:
    """Replace the text of a note bullet on the character sheet.

    The first line in the sheet that starts with ``- <before>`` has that
    prefix rewritten to ``- <after>`` via ``patch_character``. The search
    covers the whole sheet, not only the notes section.

    Args:
        args: Tool arguments; both ``before`` (existing text) and ``after``
            (new text) are required and must be non-empty.

    Returns:
        ``"Note replaced."`` on success, or an error message when an argument
        is missing or no bullet starts with ``before``.
    """
    opts = args or {}
    before = opts.get("before", "")
    after = opts.get("after", "")
    if not before or not after:
        return "**Error:** `before` and `after` are required."
    err = patch_character(
        rf"^- {re.escape(before)}",
        # Callable replacement so backslashes or "\g<...>" inside `after` are
        # written literally rather than parsed as a replacement template.
        lambda _m: f"- {after}",
        count=1,
        flags=re.MULTILINE,
    )
    if err:
        return f"**Error:** note not found: {before}"
    return "Note replaced."
def tool_world_update(args: dict) -> str:
    """Replace the world-state file with the supplied markdown.

    The content is first passed to ``validate_update_size``; a falsy result
    rejects the update and leaves the file untouched.
    NOTE(review): the size rule itself lives in ``state`` and is not visible
    here - the rejection message suggests it guards against truncated pastes.

    Args:
        args: Tool arguments; ``content`` is the full replacement text.

    Returns:
        A confirmation, or an error message when ``content`` is missing or
        rejected by the size check.
    """
    opts = args or {}
    content = opts.get("content", "")
    if not content:
        return "**Error:** `content` is required."

    accepted = validate_update_size("world", content, WORLD_PATH)
    if not accepted:
        return "**Error:** Update rejected — content is too short (likely a partial paste)."

    # Stored with surrounding whitespace trimmed and one trailing newline.
    WORLD_PATH.write_text(f"{content.strip()}\n")
    return "World state updated."
def tool_journal_update(args: dict) -> str:
    """Forward new and completed journal entries to ``update_journal``.

    ``add`` and ``done`` may each be a list or a single string; a bare string
    is wrapped in a one-element list before being forwarded.

    Args:
        args: Tool arguments with optional ``add`` and ``done`` values.

    Returns:
        ``"Journal updated."`` on success, or an error message when both
        ``add`` and ``done`` are empty or missing.
    """
    opts = args or {}
    add, done = opts.get("add", []), opts.get("done", [])
    add = [add] if isinstance(add, str) else add
    done = [done] if isinstance(done, str) else done
    if not (add or done):
        return "**Error:** Provide at least one of `add` or `done`."
    update_journal(add=add, done=done)
    return "Journal updated."
def execute_tool(tool_name: str, args: dict) -> str:
    """Dispatch a tool call to its handler and return the handler's result.

    Any exception raised by a handler is caught here: the traceback is
    written through ``append_llm_log`` and a short error string is returned.
    NOTE(review): ``player_roll`` and ``finalize_turn`` appear in
    ``TOOL_REGISTRY`` but have no handler in this table, so they yield
    ``"Unknown tool: ..."`` here - presumably handled by the caller; confirm.

    Args:
        tool_name: Name of the tool to run.
        args: Argument mapping passed to the handler unchanged.

    Returns:
        The handler's result string, ``"Unknown tool: <name>"`` for an
        unregistered name, or ``"Tool error (<name>): <exception>"``.
    """
    handlers = {
        "roll": tool_roll,
        "modify_traits": tool_modify_traits,
        "modify_vitals": tool_modify_vitals,
        "add_to_inventory": tool_add_to_inventory,
        "remove_from_inventory": tool_remove_from_inventory,
        "replace_gear": tool_replace_gear,
        "add_note": tool_add_note,
        "replace_note": tool_replace_note,
        "world_update": tool_world_update,
        "journal_update": tool_journal_update,
    }
    handler = handlers.get(tool_name)
    if handler is None:
        return f"Unknown tool: {tool_name}"

    try:
        outcome = handler(args)
    except Exception as exc:
        import traceback

        trace = traceback.format_exc()
        append_llm_log(f"\n--- TOOL ERROR ({tool_name}) ---\n{trace}")
        return f"Tool error ({tool_name}): {exc}"
    return outcome
def describe_change(tool_name: str, args: dict) -> str:
    """Build a compact human-readable change description from a tool call.

    Args:
        tool_name: Name of the tool that was called.
        args: The tool's arguments; ``None`` is treated as empty. Values are
            rendered with ``str()`` so non-string values do not raise.

    Returns:
        A one-line summary, or an empty string for tools without a summary
        (including unknown tool names) and for calls with nothing to report.
    """
    args = args or {}

    def text(key: str) -> str:
        """Return the argument as a string, or ``"?"`` when it is absent."""
        return str(args.get(key, "?"))

    def entries(key: str) -> list:
        """Return a journal argument as a list.

        A bare string becomes a single entry, matching how
        ``tool_journal_update`` treats it, instead of being iterated
        character by character.
        """
        value = args.get(key)
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    if tool_name == "modify_vitals":
        # None values are skipped, as tool_modify_vitals ignores them too.
        return ", ".join(
            f"{k.replace('_', ' ').title()}: {v}"
            for k, v in args.items()
            if v is not None
        )
    if tool_name == "modify_traits":
        return ", ".join(f"{k.upper()}: {v}" for k, v in args.items())
    if tool_name == "add_to_inventory":
        return f"+ {text('item')}"
    if tool_name == "remove_from_inventory":
        # Mirrors the "+ " marker used for additions.
        return f"- {text('item')}"
    if tool_name == "replace_gear":
        return f"{text('before')} → {text('after')}"
    if tool_name == "add_note":
        note = text("note")
        # Mark truncation explicitly when the note exceeds 60 characters.
        ellipsis = "…" if len(note) > 60 else ""
        return f"📝 {note[:60]}{ellipsis}"
    if tool_name == "replace_note":
        return f"📝 {text('before')[:40]} → {text('after')[:40]}"
    if tool_name == "world_update":
        return "🌍 World updated"
    if tool_name == "journal_update":
        parts = [f"📋 {a}" for a in entries("add")]
        # Completed entries get their own marker so they differ from added ones.
        parts.extend(f"✓ {d}" for d in entries("done"))
        return "; ".join(parts)
    return ""
def _decode_tool_block(text: str, start: int, decoder: json.JSONDecoder) -> object:
    """Decode the JSON value that begins at ``text[start]``; ``None`` on failure."""
    try:
        obj, _end = decoder.raw_decode(text, start)
        return obj
    except ValueError:
        # json.JSONDecodeError is a ValueError; fall through to the lenient
        # path below.
        pass

    close = text.find("```", start)
    if close == -1:
        return None
    raw = text[start:close].strip()
    # Escape raw newlines inside string literals so they survive decoding.
    raw = re.sub(
        r'"(?:[^"\\]|\\.)*"',
        lambda s: s.group(0).replace("\n", "\\n"),
        raw,
        flags=re.DOTALL,
    )
    try:
        # strict=False additionally tolerates other raw control characters
        # (tabs, carriage returns) inside strings.
        return json.loads(raw, strict=False)
    except ValueError:
        return None


def extract_tool_calls(text: str, on_debug: callable = None) -> list[dict]:
    """Extract tool calls from ```tool blocks in an LLM response.

    Every ```tool fence is expected to be followed by a JSON object holding a
    string ``"tool"`` name and optional ``"args"``. Blocks that fail to
    decode, are not objects, or lack a string ``"tool"`` are skipped.
    Repeated calls with the same tool name and arguments are reported once,
    in order of first appearance.

    Args:
        text: The raw LLM response.
        on_debug: Accepted for interface compatibility; not used here.

    Returns:
        The decoded call objects, de-duplicated, in document order.
    """
    decoder = json.JSONDecoder()
    calls: list[dict] = []
    seen: set[tuple[str, str]] = set()
    for fence in re.finditer(r"```tool\s*\n?", text):
        obj = _decode_tool_block(text, fence.end(), decoder)
        # A non-string tool name (e.g. a list) is not a valid call and would
        # also make the de-duplication key unhashable.
        if not isinstance(obj, dict) or not isinstance(obj.get("tool"), str):
            continue
        key = (obj["tool"], json.dumps(obj.get("args", {}), sort_keys=True))
        if key in seen:
            continue
        seen.add(key)
        calls.append(obj)
    return calls