263 lines
10 KiB
Python
263 lines
10 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import random
|
||
import re
|
||
|
||
from .paths import CHAR_PATH, WORLD_PATH
|
||
from .state import read_file, validate_update_size, update_journal, append_llm_log
|
||
|
||
|
||
# Catalogue of tools the model may call: each entry maps a tool name to a
# short description plus placeholder/example values for its arguments.
# NOTE(review): "player_roll" and "finalize_turn" have no handler in
# execute_tool's dispatch table — presumably the caller handles them; confirm.
TOOL_REGISTRY: dict[str, dict] = dict(
    roll={
        "description": "Roll dice.",
        "args": {"dice": "1d6", "modifier": "+1"},
    },
    player_roll={
        "description": "Ask player to roll.",
        "args": {"dice": "1d6", "reason": "why"},
    },
    modify_traits={
        "description": "Change STR/DEX/WIL.",
        "args": {"str": "optional", "dex": "optional", "wil": "optional"},
    },
    modify_vitals={
        "description": "Change HP, cash, weapon, armour.",
        "args": {
            "current_hp": "optional",
            "max_hp": "optional",
            "cash": "optional",
            "weapon": "optional",
            "armour": "optional",
        },
    },
    add_to_inventory={
        "description": "Add item to gear.",
        "args": {"item": "item name and stats"},
    },
    remove_from_inventory={
        "description": "Remove item from gear.",
        "args": {"item": "exact item text"},
    },
    replace_gear={
        "description": "Replace gear by exact match.",
        "args": {"before": "exact text", "after": "new text"},
    },
    add_note={
        "description": "Add note to sheet.",
        "args": {"note": "note content"},
    },
    replace_note={
        "description": "Replace note by exact match.",
        "args": {"before": "exact text", "after": "new text"},
    },
    world_update={
        "description": "Replace world state.",
        "args": {"content": "full world markdown"},
    },
    journal_update={
        "description": "Update TODO/DONE.",
        "args": {"add": "[...]", "done": "[...]"},
    },
    finalize_turn={
        "description": "End turn.",
        "args": {"ambience": "soundscape name", "log_entry": "one-line summary of what happened"},
    },
)
|
||
|
||
|
||
def patch_character(pattern: str, repl: str, count: int = 1, flags: int = 0) -> str:
    """Run a regex substitution over character.md in place.

    ``pattern``, ``repl``, ``count`` and ``flags`` are passed straight to
    ``re.subn``, so ``repl`` is interpreted as a regex replacement template
    (backslash escapes and group references are processed).

    Returns:
        An empty string after writing the patched file, or an error message
        (leaving the file untouched) when ``pattern`` matched nothing.
    """
    original = CHAR_PATH.read_text()
    patched, hits = re.subn(pattern, repl, original, count=count, flags=flags)
    if hits:
        CHAR_PATH.write_text(patched)
        return ""
    return f"**Error:** pattern not found:\n{pattern}"
|
||
|
||
|
||
def tool_roll(args: dict, *, max_dice: int = 100) -> str:
    """Roll dice given in ``NdS`` notation and report each die and the total.

    Args:
        args: Tool arguments (``None`` is treated as empty).
            ``dice`` is an ``NdS`` string such as ``"2d6"``; the count may be
            omitted (``"d6"`` means one die). Defaults to ``"1d6"``.
            ``modifier`` is an optional signed integer or numeric string
            (``"+1"``, ``-2``); an unparseable modifier is ignored (+0).
        max_dice: Upper bound on dice per roll, so a malformed request such as
            ``"1000000000d6"`` cannot exhaust memory or time.

    Returns:
        ``"Roll: 2d6 +1 → [3, 5] = 9"`` on success, or an ``Invalid dice``
        message when the notation is malformed or out of range.
    """
    params = args or {}
    # str() so a non-string value (e.g. a bare JSON number) yields the
    # "Invalid dice" message instead of an AttributeError on .lower().
    dice_str = str(params.get("dice", "1d6"))
    modifier_str = params.get("modifier", "0")

    try:
        count_txt, sides_txt = dice_str.lower().split("d")
        count = int(count_txt) if count_txt else 1
        sides = int(sides_txt)
    except (ValueError, TypeError):
        return f"Invalid dice: {dice_str}. Use format like '2d6'."

    # A zero/negative count would give an empty roll, and zero/negative sides
    # make random.randint raise; reject both with the same format message.
    if count < 1 or sides < 1:
        return f"Invalid dice: {dice_str}. Use format like '2d6'."
    if count > max_dice:
        return f"Invalid dice: {dice_str}. At most {max_dice} dice per roll."

    mod = 0
    if modifier_str:
        try:
            mod = int(modifier_str)
        except (ValueError, TypeError):
            # Best effort: a malformed modifier is treated as +0 rather than
            # failing the whole roll.
            pass

    rolls = [random.randint(1, sides) for _ in range(count)]
    total = sum(rolls) + mod
    mod_str = f" {'+' if mod >= 0 else ''}{mod}" if mod != 0 else ""
    return f"Roll: {dice_str}{mod_str} → [{', '.join(str(r) for r in rolls)}] = {total}"
|
||
|
||
|
||
def tool_modify_traits(args: dict) -> str:
    """Overwrite the STR/DEX/WIL numbers on the character sheet.

    Each supplied trait replaces the digits on its ``- **STR:** 12``-style
    line in character.md. Traits that are missing or ``None`` are left alone.

    Args:
        args: Any of ``str``, ``dex``, ``wil`` (``None`` is treated as empty).

    Returns:
        ``"Traits updated."``; an error when no trait was supplied; or the
        ``"; "``-joined errors for trait lines not found on the sheet.
    """
    params = args or {}
    updates = {
        stat: params[stat]
        for stat in ("str", "dex", "wil")
        if params.get(stat) is not None
    }
    if not updates:
        return "**Error:** Provide at least one of `str`, `dex` or `wil`."

    errors = []
    for stat, val in updates.items():
        # patch_character passes the replacement to re.subn as a template;
        # double backslashes so the value is inserted literally.
        literal = str(val).replace("\\", "\\\\")
        # [ \t]* rather than \s*: under MULTILINE, \s* can cross the newline
        # of an empty trait line and rewrite digits on the following line.
        err = patch_character(
            rf"^(- \*\*{stat.upper()}:\*\*[ \t]*)\d+", rf"\g<1>{literal}", count=1, flags=re.MULTILINE
        )
        if err:
            errors.append(err)
    return "; ".join(errors) if errors else "Traits updated."
|
||
|
||
|
||
def tool_modify_vitals(args: dict) -> str:
    """Overwrite vital fields (HP, cash, weapon, armour) on the character sheet.

    Each supplied field replaces everything after the label on its
    ``- **Label:**`` line in character.md. Fields that are missing or ``None``
    are left untouched.

    Args:
        args: Any of ``current_hp``, ``max_hp``, ``cash``, ``weapon``,
            ``armour`` (``None`` is treated as empty).

    Returns:
        ``"Vitals updated."``; an error when no field was supplied; or the
        ``"; "``-joined errors for labels not found on the sheet.
    """
    params = args or {}
    fields = (
        ("current_hp", "Current Health"),
        ("max_hp", "Max Health"),
        ("cash", "Cash"),
        ("weapon", "Weapon"),
        ("armour", "Armour"),
    )
    updates = [
        (label, params[field])
        for field, label in fields
        if params.get(field) is not None
    ]
    if not updates:
        return "**Error:** Provide at least one of `current_hp`, `max_hp`, `cash`, `weapon` or `armour`."

    errors = []
    for label, val in updates:
        # patch_character passes the replacement to re.subn as a template;
        # double backslashes so the value is inserted literally.
        literal = str(val).replace("\\", "\\\\")
        # [ \t]* rather than \s*: under MULTILINE, \s* would swallow the
        # newline after an empty field (e.g. "- **Weapon:**") and `.*` would
        # then overwrite the NEXT line of the sheet.
        err = patch_character(
            rf"^(- \*\*{label}:\*\*[ \t]*).*", rf"\g<1>{literal}", count=1, flags=re.MULTILINE
        )
        if err:
            errors.append(err)
    return "; ".join(errors) if errors else "Vitals updated."
|
||
|
||
|
||
def tool_add_to_inventory(args: dict) -> str:
    """Add an item bullet to the ``## Gear`` section of character.md.

    The new ``- item`` line is inserted directly below the ``## Gear``
    heading. If the sheet has no such heading, a new Gear section is appended
    at the end of the file.

    Args:
        args: ``{"item": "<item name and stats>"}`` (``None`` is treated as
            empty). Surrounding whitespace is stripped from the item.

    Returns:
        A confirmation; an "already in inventory" notice when an identical
        ``- item`` line exists; or an error when ``item`` is missing or blank.
    """
    item = str((args or {}).get("item") or "").strip()
    if not item:
        return "**Error:** `item` is required."

    text = CHAR_PATH.read_text()
    # Only an identical bullet line counts as a duplicate. A plain substring
    # test over the whole sheet would reject items merely mentioned elsewhere,
    # e.g. "Sword" inside a "- **Weapon:** Sword (d6)" line.
    if re.search(rf"^- {re.escape(item)}[ \t]*$", text, re.MULTILINE):
        return f"Item already in inventory: {item}"

    # Tolerate trailing spaces after the heading; otherwise a stray space
    # would cause a second "## Gear" section to be appended.
    gear_heading = re.search(r"^## Gear[ \t]*\n", text, re.MULTILINE)
    if gear_heading:
        insert_at = gear_heading.end()
        text = text[:insert_at] + f"- {item}\n" + text[insert_at:]
    else:
        text += f"\n## Gear\n- {item}\n"
    CHAR_PATH.write_text(text)
    return f"Added to inventory: {item}"
|
||
|
||
|
||
def tool_remove_from_inventory(args: dict) -> str:
    """Delete the first ``- item`` bullet line from character.md.

    The whole line must match (trailing spaces/tabs allowed), so ``item`` has
    to be the exact bullet text. The search covers the entire file, not only
    the ``## Gear`` section.

    Args:
        args: ``{"item": "<exact item text>"}`` (``None`` is treated as
            empty). Surrounding whitespace is stripped from the item.

    Returns:
        A confirmation, or an error when ``item`` is missing or not found.
    """
    item = str((args or {}).get("item") or "").strip()
    if not item:
        return "**Error:** `item` is required."
    # Anchor at end of line and consume the newline (or end of file). A bare
    # prefix match would delete only "- Rope" from "- Rope (50ft)" and leave
    # " (50ft)" behind as a broken line.
    err = patch_character(
        rf"^- {re.escape(item)}[ \t]*(?:\n|\Z)", "", count=1, flags=re.MULTILINE
    )
    if err:
        return f"**Error:** item not found: {item}"
    return f"Removed from inventory: {item}"
|
||
|
||
|
||
def tool_replace_gear(args: dict) -> str:
    """Rewrite the first line of character.md that starts with ``- before``.

    Only the matched ``- before`` prefix is replaced with ``- after``; any
    text following it on the same line is kept.

    Args:
        args: ``{"before": "<exact text>", "after": "<new text>"}``
            (``None`` is treated as empty).

    Returns:
        A confirmation, or an error when an argument is missing or no line
        starts with ``- before``.
    """
    params = args or {}
    before = params.get("before", "")
    after = params.get("after", "")
    if not before or not after:
        return "**Error:** `before` and `after` are required."
    # patch_character hands the replacement to re.subn, which processes
    # backslash escapes and group references (\1, \g<..>, \n). Double the
    # backslashes so `after` is inserted verbatim instead of raising re.error
    # or injecting control characters.
    replacement = "- " + str(after).replace("\\", "\\\\")
    err = patch_character(rf"^- {re.escape(before)}", replacement, count=1, flags=re.MULTILINE)
    if err:
        return f"**Error:** gear not found: {before}"
    return f"Gear replaced: {before} → {after}"
|
||
|
||
|
||
def tool_add_note(args: dict) -> str:
    """Prepend a bullet note under the sheet's ``## Notes & Scribbles`` heading.

    The ``- note`` line goes directly below the heading. If the heading is
    absent, a new Notes & Scribbles section is appended to character.md.

    Args:
        args: ``{"note": "<note content>"}`` (``None`` is treated as empty).

    Returns:
        A confirmation, or an error when ``note`` is missing or empty.
    """
    note = (args or {}).get("note", "")
    if not note:
        return "**Error:** `note` is required."

    sheet = CHAR_PATH.read_text()
    bullet = f"- {note}\n"
    heading = re.search(r"^## Notes & Scribbles\n", sheet, re.MULTILINE)
    if heading is None:
        sheet = sheet + f"\n## Notes & Scribbles\n{bullet}"
    else:
        cut = heading.end()
        sheet = f"{sheet[:cut]}{bullet}{sheet[cut:]}"
    CHAR_PATH.write_text(sheet)
    return f"Note added: {note}"
|
||
|
||
|
||
def tool_replace_note(args: dict) -> str:
    """Rewrite the first line of character.md that starts with ``- before``.

    Only the matched ``- before`` prefix is replaced with ``- after``; any
    text following it on the same line is kept. The search covers the whole
    file, not only the Notes section.

    Args:
        args: ``{"before": "<exact text>", "after": "<new text>"}``
            (``None`` is treated as empty).

    Returns:
        ``"Note replaced."``, or an error when an argument is missing or no
        line starts with ``- before``.
    """
    params = args or {}
    before = params.get("before", "")
    after = params.get("after", "")
    if not before or not after:
        return "**Error:** `before` and `after` are required."
    # patch_character hands the replacement to re.subn, which processes
    # backslash escapes and group references (\1, \g<..>, \n). Double the
    # backslashes so `after` is inserted verbatim instead of raising re.error
    # or injecting control characters.
    replacement = "- " + str(after).replace("\\", "\\\\")
    err = patch_character(rf"^- {re.escape(before)}", replacement, count=1, flags=re.MULTILINE)
    if err:
        return f"**Error:** note not found: {before}"
    return "Note replaced."
|
||
|
||
|
||
def tool_world_update(args: dict) -> str:
    """Replace the world-state markdown file with new content.

    The content is first checked with ``validate_update_size`` (presumably a
    guard against truncated updates — per the rejection message below, it
    refuses content it considers too short). Accepted content is written
    stripped, with a single trailing newline.

    Args:
        args: ``{"content": "<full world markdown>"}`` (``None`` is treated
            as empty).

    Returns:
        ``"World state updated."`` or an error message.
    """
    new_world = (args or {}).get("content", "")
    if not new_world:
        return "**Error:** `content` is required."
    if validate_update_size("world", new_world, WORLD_PATH):
        WORLD_PATH.write_text(new_world.strip() + "\n")
        return "World state updated."
    return "**Error:** Update rejected — content is too short (likely a partial paste)."
|
||
|
||
|
||
def tool_journal_update(args: dict) -> str:
    """Add TODO entries to, and/or mark entries DONE in, the journal.

    Args:
        args: ``{"add": [...], "done": [...]}`` (``None`` is treated as
            empty). Either value may be a single string (one entry), a list
            of entries, or ``null``/missing (no entries). Blank entries are
            dropped.

    Returns:
        ``"Journal updated."``, or an error when neither list has entries.
    """
    params = args or {}

    def _entries(value) -> list:
        # A lone string is one entry; null/missing means no entries (an
        # explicit JSON null previously reached update_journal as None).
        if value is None:
            return []
        items = [value] if isinstance(value, str) else list(value)
        # Drop blank entries so an empty string is not forwarded as an entry.
        return [item for item in items if str(item).strip()]

    add = _entries(params.get("add"))
    done = _entries(params.get("done"))
    if not add and not done:
        return "**Error:** Provide at least one of `add` or `done`."
    update_journal(add=add, done=done)
    return "Journal updated."
|
||
|
||
|
||
def execute_tool(tool_name: str, args: dict) -> str:
    """Dispatch one parsed tool call to its handler and return its result.

    Names without a handler here (including ``player_roll`` and
    ``finalize_turn`` from TOOL_REGISTRY — presumably handled by the caller)
    yield ``"Unknown tool: <name>"``. Any exception raised by a handler is
    caught at this boundary: its traceback is appended to the LLM log via
    ``append_llm_log`` and a ``"Tool error (<name>): <exc>"`` string is
    returned instead.
    """
    handlers = {
        "roll": tool_roll,
        "modify_traits": tool_modify_traits,
        "modify_vitals": tool_modify_vitals,
        "add_to_inventory": tool_add_to_inventory,
        "remove_from_inventory": tool_remove_from_inventory,
        "replace_gear": tool_replace_gear,
        "add_note": tool_add_note,
        "replace_note": tool_replace_note,
        "world_update": tool_world_update,
        "journal_update": tool_journal_update,
    }
    handler = handlers.get(tool_name)
    if handler is None:
        return f"Unknown tool: {tool_name}"
    try:
        return handler(args)
    except Exception as exc:
        import traceback

        append_llm_log(f"\n--- TOOL ERROR ({tool_name}) ---\n{traceback.format_exc()}")
        return f"Tool error ({tool_name}): {exc}"
|
||
|
||
|
||
def describe_change(tool_name: str, args: dict) -> str:
    """Build a compact, human-readable summary of a tool call's change.

    Args:
        tool_name: Name of the tool that was called.
        args: The call's arguments (``None`` is treated as empty).

    Returns:
        A one-line summary (prefixed with an emoji/symbol per tool), or ``""``
        for tools without a summary or calls that change nothing.
    """
    params = args or {}

    def _as_list(value) -> list:
        # Mirror tool_journal_update: a lone string is one entry and null
        # means none. Iterating a bare string would emit one part per char.
        if value is None:
            return []
        return [value] if isinstance(value, str) else list(value)

    def _clip(value, limit: int) -> str:
        text = str(value)
        return text[:limit] + ("…" if len(text) > limit else "")

    if tool_name == "modify_vitals":
        # None values are skipped by the tool itself, so don't report them.
        parts = [
            f"{k.replace('_', ' ').title()}: {v}"
            for k, v in params.items()
            if v is not None
        ]
        return f"⚡ {', '.join(parts)}" if parts else ""
    if tool_name == "modify_traits":
        parts = [f"{k.upper()}: {v}" for k, v in params.items() if v is not None]
        # Empty result for a no-op call, consistent with modify_vitals.
        return f"⚡ {', '.join(parts)}" if parts else ""
    if tool_name == "add_to_inventory":
        return f"+ {params.get('item', '?')}"
    if tool_name == "remove_from_inventory":
        return f"− {params.get('item', '?')}"
    if tool_name == "replace_gear":
        return f"↻ {params.get('before', '?')} → {params.get('after', '?')}"
    if tool_name == "add_note":
        return f"📝 {_clip(params.get('note', '?'), 60)}"
    if tool_name == "replace_note":
        return f"📝 {_clip(params.get('before', '?'), 40)} → {_clip(params.get('after', '?'), 40)}"
    if tool_name == "world_update":
        return "🌍 World updated"
    if tool_name == "journal_update":
        parts = [f"📋 {a}" for a in _as_list(params.get("add"))]
        parts += [f"✅ {d}" for d in _as_list(params.get("done"))]
        return "; ".join(parts) if parts else ""
    return ""
|
||
|
||
|
||
def extract_tool_calls(text: str) -> list[dict]:
    """Extract tool calls from ```tool fenced blocks in an LLM response.

    For each ```tool fence, the JSON object right after it is parsed. Parsing
    is first attempted strictly from the fence position. If that fails, the
    text up to the next ``` fence is re-parsed leniently, accepting raw
    control characters (newlines, tabs) inside strings.

    Only JSON objects with a string ``"tool"`` key are kept. Calls with the
    same tool name and the same ``args`` are returned once, in first-seen
    order.

    Returns:
        The parsed call dicts, e.g. ``[{"tool": "roll", "args": {...}}]``.
    """
    decoder = json.JSONDecoder()
    calls: list[dict] = []
    seen: set = set()

    for fence in re.finditer(r"```tool\s*\n?", text):
        start = fence.end()
        try:
            # raw_decode tolerates trailing text (normally the closing fence).
            obj, _ = decoder.raw_decode(text, start)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            close = text.find("```", start)
            if close == -1:
                continue
            try:
                # strict=False admits literal control characters inside
                # strings, which models often emit in multi-line values.
                obj = json.loads(text[start:close].strip(), strict=False)
            except ValueError:
                continue

        tool = obj.get("tool") if isinstance(obj, dict) else None
        # The name goes into a set key; an unhashable value (e.g. a list)
        # would raise TypeError, and a non-string is not a valid tool name.
        if not isinstance(tool, str):
            continue

        key = (tool, json.dumps(obj.get("args", {}), sort_keys=True))
        if key in seen:
            continue
        seen.add(key)
        calls.append(obj)

    return calls
|