diff --git a/src/bot.py b/src/bot.py
index cf03b85..b8b322d 100644
--- a/src/bot.py
+++ b/src/bot.py
@@ -1,4 +1,5 @@
import html
+import json
import re
import time
import urllib.parse
@@ -30,8 +31,28 @@ channels_dir = ensure_dir(data_dir / "channels")
validating = set()
# State tracking for tree hole conversations
-# user_id -> {"action": "treehole"|"reply", "channel": str, "sender_id": int (for reply)}
-user_states: dict[int, dict] = {}
+# user_id -> {"action": "treehole"|"reply"|"leaf", ...}
+states_file = data_dir / "user_states.json"
+
+
+def load_states() -> dict[int, dict]:
+ if states_file.exists():
+ raw = json.loads(states_file.read_text(encoding='utf-8'))
+ return {int(k): v for k, v in raw.items()}
+ return {}
+
+
+def set_state(user_id: int, state: dict | None):
+ """Set or clear the state for a user, and persist to disk."""
+ if state is None:
+ user_states.pop(user_id, None)
+ else:
+ user_states[user_id] = state
+ states_file.write_text(json.dumps({str(k): v for k, v in user_states.items()},
+ ensure_ascii=False), encoding='utf-8')
+
+
+user_states: dict[int, dict] = load_states()
# Rate limiting for tree hole messages: user_id -> last send timestamp
treehole_rate_limit: dict[int, float] = {}
@@ -157,7 +178,7 @@ async def handle_leaf(update: Update, user_id: int, parent: str):
if not db.channel_info(parent):
return await update.message.reply_text("上级频道还不在树上... 是不是打错了 qwq")
- user_states[user_id] = {"action": "leaf", "parent": parent}
+ set_state(user_id, {"action": "leaf", "parent": parent})
await update.message.reply_html(f"🌿 成为树叶\n\n你想让哪个频道成为 @{parent} 的树叶呢?(请发送你的频道的 @用户名)")
@@ -205,7 +226,7 @@ async def handle_treehole(update: Update, user_id: int, channel: str):
remaining = int(TREEHOLE_COOLDOWN - (time.time() - last_time))
return await update.message.reply_text(f"发送太频繁了,请 {remaining} 秒后再试~")
- user_states[user_id] = {"action": "treehole", "channel": channel}
+ set_state(user_id, {"action": "treehole", "channel": channel})
return await update.message.reply_html(f"🕳️ 树洞模式\n\n想对频道 @{channel} 的主人说什么呢?"
f"(发送文字消息即可,消息将会匿名发送)")
@@ -229,7 +250,7 @@ async def reply_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
if actual_owner != owner_id:
return await query.answer("只有频道主人才能回复哦~", show_alert=False)
- user_states[owner_id] = {"action": "reply", "sender_id": sender_id, "channel": channel}
+ set_state(owner_id, {"action": "reply", "sender_id": sender_id, "channel": channel})
await query.answer()
return await context.bot.send_message(
chat_id=owner_id,
@@ -271,7 +292,8 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
state = user_states.get(user_id)
if not state:
- return
+ return await update.message.reply_text("🌳 不太明白你在说什么哦~ 请点击频道消息上的按钮来互动吧!\n"
+ "(如果哪里不对的话,应该是 bot 重启了,重新点击按钮就好 ;-;)")
if state["action"] == "leaf":
parent = state["parent"]
@@ -296,7 +318,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
"以及需要输入频道的 @用户名,不是显示名哦~")
# Success paths — clear state
- del user_states[user_id]
+ set_state(user_id, None)
info = utils.extract_meta_tags(text)
if sha in text:
@@ -316,13 +338,13 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
不过上树之前,为了防止被滥用,需要先验证一下你是 {channel} 的管理员...
-请编辑频道简介加入验证码 {sha} 再点击下面的「添加好了」吧~(加在哪里都可以的 > < 验证完就可以删掉)
-""".strip(), reply_markup=verify_btn)
+请编辑频道简介加入验证码 {sha} 再点击下面的「添加好了」吧~(加在哪里都可以的 > < 验证完就可以删掉)
+""".strip(), reply_markup=verify_btn, parse_mode="HTML")
validating.add(sha)
elif state["action"] == "treehole":
channel = state["channel"]
- del user_states[user_id]
+ set_state(user_id, None)
# Update rate limit
treehole_rate_limit[user_id] = time.time()
@@ -376,7 +398,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
elif state["action"] == "reply":
sender_id = state["sender_id"]
channel = state["channel"]
- del user_states[user_id]
+ set_state(user_id, None)
try:
escaped_text = html.escape(update.message.text)
@@ -435,6 +457,10 @@ async def verify_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
text = channel_html(channel)
if sha in text:
+ # Check if already registered (e.g. double-click)
+ if db.channel_info(channel):
+ return await query.answer("这个频道已经在树上了哦~", show_alert=False)
+
info = utils.extract_meta_tags(text)
title = info.title or channel
logger.info(f"> 🌿 Registering channel {channel} with parent {parent}.")
@@ -482,6 +508,24 @@ layout_html = (Path(__file__).parent.parent / "public" / "layout.html").read_tex
fmt_html = lambda x: layout_html.replace("{{CONTENT}}", x).replace("\n", "")
+def tree_to_dict(channel: str) -> dict | None:
+ """Recursively build a dict representation of the tree for the API."""
+ info = db.channel_info(channel)
+ if not info:
+ return None
+ return {
+ "username": channel,
+ "name": info.name,
+ "water": db.get_votes(channel),
+ "children": [tree_to_dict(c.username) for c in info.children],
+ }
+
+
+@app.get("/api/tree")
+def api_tree():
+ return tree_to_dict("azaneko")
+
+
@app.get("/c/{channel}", response_class=HTMLResponse)
def channel_info(channel: str):
info = db.channel_info(channel)
@@ -493,6 +537,8 @@ def channel_info(channel: str):
""")
leaf_txt = '树枝' if info.children else '树叶'
+ votes = db.get_votes(channel)
+ water_html = f'
💧 这个频道已经被浇了 {votes} 次水~
' if votes else '' return fmt_html(f"""下面这些是这个频道的树枝: