[F] Fix treehole security
This commit is contained in:
+29
-10
@@ -236,13 +236,20 @@ async def reply_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
query = update.callback_query
|
||||
data = query.data
|
||||
|
||||
# Format: reply:{sender_id}:{channel}
|
||||
parts = data.split(":", 2)
|
||||
if len(parts) != 3:
|
||||
|
||||
# New format: reply:{msg_id}
|
||||
if len(parts) == 2:
|
||||
msg = db.get_treehole_msg(int(parts[1]))
|
||||
if not msg:
|
||||
return await query.answer("这条消息已经过期了哦~", show_alert=False)
|
||||
sender_id, channel = msg.sender_id, msg.channel_id
|
||||
# Old format (backward compat): reply:{sender_id}:{channel}
|
||||
elif len(parts) == 3:
|
||||
sender_id, channel = int(parts[1]), parts[2]
|
||||
else:
|
||||
return
|
||||
|
||||
_, sender_id_str, channel = parts
|
||||
sender_id = int(sender_id_str)
|
||||
owner_id = query.from_user.id
|
||||
|
||||
# Verify the person clicking is the channel owner
|
||||
@@ -263,13 +270,20 @@ async def block_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
query = update.callback_query
|
||||
data = query.data
|
||||
|
||||
# Format: block:{sender_id}:{channel}
|
||||
parts = data.split(":", 2)
|
||||
if len(parts) != 3:
|
||||
|
||||
# New format: block:{msg_id}
|
||||
if len(parts) == 2:
|
||||
msg = db.get_treehole_msg(int(parts[1]))
|
||||
if not msg:
|
||||
return await query.answer("这条消息已经过期了哦~", show_alert=False)
|
||||
sender_id, channel = msg.sender_id, msg.channel_id
|
||||
# Old format (backward compat): block:{sender_id}:{channel}
|
||||
elif len(parts) == 3:
|
||||
sender_id, channel = int(parts[1]), parts[2]
|
||||
else:
|
||||
return
|
||||
|
||||
_, sender_id_str, channel = parts
|
||||
sender_id = int(sender_id_str)
|
||||
owner_id = query.from_user.id
|
||||
|
||||
# Verify the person clicking is the channel owner
|
||||
@@ -377,9 +391,10 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
pass
|
||||
|
||||
# Send to owner anonymously
|
||||
msg_id = db.create_treehole_msg(user_id, channel)
|
||||
reply_btn = InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("💬 回复", callback_data=f"reply:{user_id}:{channel}")],
|
||||
[InlineKeyboardButton("🚫 屏蔽发送者", callback_data=f"block:{user_id}:{channel}")],
|
||||
[InlineKeyboardButton("💬 回复", callback_data=f"reply:{msg_id}")],
|
||||
[InlineKeyboardButton("🚫 屏蔽发送者", callback_data=f"block:{msg_id}")],
|
||||
])
|
||||
|
||||
try:
|
||||
@@ -402,9 +417,13 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
|
||||
try:
|
||||
escaped_text = html.escape(update.message.text)
|
||||
reply_back_btn = InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("💬 回复", url=f"https://t.me/{BOT_NAME}?start=th_{channel}")],
|
||||
])
|
||||
await context.bot.send_message(
|
||||
chat_id=sender_id,
|
||||
text=f"💬 <b>频道 @{channel} 的主人回复了你的树洞消息:</b>\n\n{escaped_text}",
|
||||
reply_markup=reply_back_btn,
|
||||
parse_mode="HTML"
|
||||
)
|
||||
await update.message.reply_text("✅ 回复已发送~")
|
||||
|
||||
@@ -42,8 +42,14 @@ class OwnerPref(BaseModel):
|
||||
treehole_notified = BooleanField(default=False) # Whether the owner has been sent the intro notice
|
||||
|
||||
|
||||
class TreeholeMsg(BaseModel):
|
||||
"""Stores metadata for treehole messages so callback data doesn't expose user IDs."""
|
||||
sender_id = BigIntegerField() # The anonymous sender's Telegram user ID
|
||||
channel = ForeignKeyField(Channel, backref='treehole_msgs', on_delete='CASCADE', field='username')
|
||||
|
||||
|
||||
with db:
|
||||
db.create_tables([Channel, Vote, Block, OwnerPref])
|
||||
db.create_tables([Channel, Vote, Block, OwnerPref, TreeholeMsg])
|
||||
|
||||
|
||||
def channel_info(username: str) -> Channel | None:
|
||||
@@ -116,6 +122,20 @@ def is_blocked(user_id: int, channel_username: str) -> bool:
|
||||
).exists()
|
||||
|
||||
|
||||
def create_treehole_msg(sender_id: int, channel_username: str) -> int:
|
||||
"""Create a treehole message record and return its numeric ID."""
|
||||
msg = TreeholeMsg.create(sender_id=sender_id, channel=channel_username)
|
||||
return msg.id
|
||||
|
||||
|
||||
def get_treehole_msg(msg_id: int) -> TreeholeMsg | None:
|
||||
"""Look up a treehole message by its ID."""
|
||||
try:
|
||||
return TreeholeMsg.get_by_id(msg_id)
|
||||
except TreeholeMsg.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def is_treehole_opted_out(owner_id: int) -> bool:
|
||||
"""Check if an owner has opted out of receiving tree hole messages."""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user