-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
Copy pathutils.py
434 lines (360 loc) · 15 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
from __future__ import annotations
import asyncio
import itertools
import json
import logging
import os
import base64
import io
from PIL import Image, ImageChops
import telegram
from telegram import Message, MessageEntity, Update, ChatMember, constants
from telegram.ext import CallbackContext, ContextTypes
from usage_tracker import UsageTracker
def message_text(message: Message) -> str:
    """
    Returns the message text with its bot command entities removed.
    :param message: The Telegram message to read
    :return: The remaining text, or '' when the message has no text
    """
    text = message.text
    if text is None:
        return ''
    # Handle commands in the order they appear in the message (by entity offset).
    commands = message.parse_entities([MessageEntity.BOT_COMMAND])
    ordered_commands = sorted(commands.items(), key=lambda pair: pair[0].offset)
    for _, command in ordered_commands:
        # NOTE(review): str.replace removes every occurrence of the command text,
        # not just the entity's own span.
        text = text.replace(command, '').strip()
    return text
async def is_user_in_group(update: Update, context: CallbackContext, user_id: int) -> bool:
    """
    Checks if user_id is an owner, administrator or member of the chat the update's
    message was sent in.
    :param update: Telegram update object; the chat of its message is checked
    :param context: The callback context used for the Bot API call
    :param user_id: The user id to look up
    :return: True for owner/administrator/member status; False for any other status
             or when Telegram reports "User not found"
    :raises telegram.error.BadRequest: for any other bad request
    """
    try:
        member = await context.bot.get_chat_member(update.message.chat_id, user_id)
    except telegram.error.BadRequest as err:
        if str(err) != "User not found":
            raise
        return False
    present_statuses = (ChatMember.OWNER, ChatMember.ADMINISTRATOR, ChatMember.MEMBER)
    return member.status in present_statuses
def get_thread_id(update: Update) -> int | None:
    """
    Returns the message thread (topic) id of the update's effective message.
    :param update: Telegram update object
    :return: message_thread_id if the effective message is a topic message, else None
    """
    message = update.effective_message
    if not message or not message.is_topic_message:
        return None
    return message.message_thread_id
def get_stream_cutoff_values(update: Update, content: str) -> int:
    """
    Gets the stream cutoff value for content of the given length.
    Longer content yields a larger value, and group chats use larger values
    than other chats.
    :param update: Telegram update object, used to detect group chats
    :param content: The message content to measure
    :return: The cutoff value
    """
    if is_group_chat(update):
        # group chats have stricter flood limits
        tiers, smallest = ((1000, 180), (200, 120), (50, 90)), 50
    else:
        tiers, smallest = ((1000, 90), (200, 45), (50, 25)), 15
    length = len(content)
    # Tiers run from the longest bound down; the first bound strictly exceeded wins.
    return next((cutoff for bound, cutoff in tiers if length > bound), smallest)
def is_group_chat(update: Update) -> bool:
    """
    Checks if the update's effective chat is a group or a supergroup.
    :param update: Telegram update object
    :return: True for group/supergroup chats; False otherwise, including when
             there is no effective chat
    """
    chat = update.effective_chat
    if chat:
        return chat.type in (constants.ChatType.GROUP, constants.ChatType.SUPERGROUP)
    return False
def split_into_chunks(text: str, chunk_size: int = 4096) -> list[str]:
    """
    Splits a string into consecutive chunks of at most chunk_size characters.
    Every chunk except possibly the last has exactly chunk_size characters.
    :param text: The string to split
    :param chunk_size: Maximum length of each chunk; must be positive
    :return: The chunks in order; an empty list for an empty string
    :raises ValueError: if chunk_size is not positive
    """
    if chunk_size <= 0:
        # A negative step would make range() yield nothing and silently drop the text.
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
async def wrap_with_indicator(update: Update, context: CallbackContext, coroutine,
                              chat_action: constants.ChatAction = "", is_inline=False):
    """
    Runs a coroutine while repeatedly showing a chat action to the user.
    coroutine() is scheduled as an application task. Until that task is done,
    chat_action is sent to the effective chat (unless is_inline), and the task is
    awaited for up to 4.5 seconds before the action is sent again.
    :param update: Telegram update object
    :param context: The callback context providing the application
    :param coroutine: Zero-argument callable returning the awaitable to run
    :param chat_action: The chat action to display while waiting
    :param is_inline: Whether the update is an inline query; no chat action is sent then
    :return: None; the coroutine's result is not returned, but its exception propagates
    """
    app = context.application
    task = app.create_task(coroutine(), update=update)
    while not task.done():
        if not is_inline:
            action = update.effective_chat.send_action(chat_action, message_thread_id=get_thread_id(update))
            # Scheduled separately so sending the indicator never delays waiting on the task.
            app.create_task(action)
        try:
            # shield() keeps the timeout from cancelling the wrapped task.
            await asyncio.wait_for(asyncio.shield(task), 4.5)
        except asyncio.TimeoutError:
            continue
async def edit_message_with_retry(context: ContextTypes.DEFAULT_TYPE, chat_id: int | None,
                                  message_id: str, text: str, markdown: bool = True, is_inline: bool = False):
    """
    Edit a message with retry logic in case of failure (e.g. broken markdown).
    The first attempt uses Markdown parse mode when markdown is True. On a BadRequest
    other than "Message is not modified", the edit is retried once without a parse mode.
    :param context: The context to use
    :param chat_id: The chat id to edit the message in
    :param message_id: The message id to edit (the inline message id when is_inline)
    :param text: The text to edit the message with
    :param markdown: Whether to use markdown parse mode on the first attempt
    :param is_inline: Whether the message to edit is an inline message
    :return: None
    :raises Exception: the error of the failing attempt, after logging a warning
    """
    def target_ids() -> dict:
        # Inline messages are addressed by inline_message_id, others by integer message_id.
        if is_inline:
            return {'message_id': None, 'inline_message_id': message_id}
        return {'message_id': int(message_id), 'inline_message_id': None}

    try:
        await context.bot.edit_message_text(
            chat_id=chat_id,
            text=text,
            parse_mode=constants.ParseMode.MARKDOWN if markdown else None,
            **target_ids(),
        )
    except telegram.error.BadRequest as err:
        if str(err).startswith("Message is not modified"):
            # The content is already identical; nothing left to do.
            return
        try:
            await context.bot.edit_message_text(chat_id=chat_id, text=text, **target_ids())
        except Exception as retry_err:
            logging.warning(f'Failed to edit message: {str(retry_err)}')
            raise retry_err
    except Exception as err:
        logging.warning(str(err))
        raise err
async def error_handler(_: object, context: ContextTypes.DEFAULT_TYPE) -> None:
    """
    Handles errors in the telegram-python-bot library.
    Logs context.error at ERROR level together with its traceback.
    :param _: First positional argument (ignored)
    :param context: The callback context; context.error holds the raised exception
    :return: None
    """
    # Passing the exception as exc_info keeps its traceback in the log; the message
    # alone is usually not enough to locate where the error happened.
    logging.error(f'Exception while handling an update: {context.error}', exc_info=context.error)
async def is_allowed(config, update: Update, context: CallbackContext, is_inline=False) -> bool:
    """
    Checks if the user is allowed to use the bot.
    A user is allowed when the allowed user list is '*', when they are an admin, or when
    their id is in the allowed user list. For non-inline group chat messages, the message
    is also allowed if any allowed or admin user is a member of that group.
    :param config: The bot configuration object
    :param update: Telegram update object
    :param context: The callback context, used for group membership lookups
    :param is_inline: Whether the update is an inline query
    :return: True if the user may use the bot, False otherwise
    """
    if config['allowed_user_ids'] == '*':
        return True

    user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
    if is_admin(config, user_id):
        return True

    name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
    allowed_user_ids = config['allowed_user_ids'].split(',')
    # Check if user is allowed
    if str(user_id) in allowed_user_ids:
        return True

    # Check if it's a group a chat with at least one authorized member
    if not is_inline and is_group_chat(update):
        # '-' means "no admin configured" (as in is_admin); it is not a user id and
        # looking it up in the group would fail with a BadRequest.
        if config['admin_user_ids'] == '-':
            admin_user_ids = []
        else:
            admin_user_ids = config['admin_user_ids'].split(',')
        for user in itertools.chain(allowed_user_ids, admin_user_ids):
            member_id = user.strip()
            if not member_id:
                continue
            if await is_user_in_group(update, context, member_id):
                logging.info(f'{member_id} is a member. Allowing group chat message...')
                return True
        logging.info(f'Group chat messages from user {name} '
                     f'(id: {user_id}) are not allowed')
    return False
def is_admin(config, user_id: int, log_no_admin=False) -> bool:
    """
    Checks if the user is one of the bot's admins.
    Admins are the comma-separated ids in config['admin_user_ids']; every listed id is
    an admin. The value '-' means that no admin is configured.
    :param config: The bot configuration object
    :param user_id: The user id to check
    :param log_no_admin: Whether to log an info message when no admin is configured
    :return: True if user_id appears in the admin list, False otherwise
    """
    if config['admin_user_ids'] == '-':
        if log_no_admin:
            logging.info('No admin user defined.')
        return False
    # Check if user is in the admin user list
    return str(user_id) in config['admin_user_ids'].split(',')
def get_user_budget(config, user_id) -> float | None:
    """
    Get the user's budget based on their user ID and the bot configuration.
    Admins and a '*' budget list get an unlimited budget (float('inf')). With an
    unrestricted user list ('*'), everyone gets the first budget value. Otherwise the
    budget at the same position as the user's id in the allowed user list is used.
    :param config: The bot configuration object
    :param user_id: User id
    :return: The user's budget as a float, 0.0 when the budget list is shorter than the
             user's position, or None if the user is not in the allowed user list
    """
    # no budget restrictions for admins and '*'-budget lists
    if is_admin(config, user_id) or config['user_budgets'] == '*':
        return float('inf')

    budgets = config['user_budgets'].split(',')

    if config['allowed_user_ids'] == '*':
        # one shared budget: only the first entry counts
        if len(budgets) > 1:
            logging.warning('multiple values for budgets set with unrestricted user list '
                            'only the first value is used as budget for everyone.')
        return float(budgets[0])

    try:
        position = config['allowed_user_ids'].split(',').index(str(user_id))
    except ValueError:
        # not in the allowed user list
        return None

    if position >= len(budgets):
        logging.warning(f'No budget set for user id: {user_id}. Budget list shorter than user list.')
        return 0.0
    return float(budgets[position])
def get_remaining_budget(config, usage, update: Update, is_inline=False) -> float:
    """
    Calculate the remaining budget for a user based on their current usage.
    Creates the user's UsageTracker if it is missing. Users without a budget of their own
    (get_user_budget returns None) draw from the shared 'guests' tracker and
    config['guest_budget']; that tracker is created on first use.
    :param config: The bot configuration object
    :param usage: The usage tracker object
    :param update: Telegram update object
    :param is_inline: Boolean flag for inline queries
    :return: The remaining budget for the user as a float
    """
    # budget period -> key of the matching figure in get_current_cost()
    period_to_cost_key = {
        "monthly": "cost_month",
        "daily": "cost_today",
        "all-time": "cost_all_time"
    }

    sender = update.inline_query.from_user if is_inline else update.message.from_user
    user_id = sender.id
    if user_id not in usage:
        usage[user_id] = UsageTracker(user_id, sender.name)

    budget = get_user_budget(config, user_id)
    period = config['budget_period']

    if budget is None:
        # Get budget for guests
        if 'guests' not in usage:
            usage['guests'] = UsageTracker('guests', 'all guest users in group chats')
        spent = usage['guests'].get_current_cost()[period_to_cost_key[period]]
        return config['guest_budget'] - spent

    spent = usage[user_id].get_current_cost()[period_to_cost_key[period]]
    return budget - spent
def is_within_budget(config, usage, update: Update, is_inline=False) -> bool:
    """
    Checks if the user still has budget left.
    Creates the user's UsageTracker when it does not exist yet; the guest tracker is
    created by get_remaining_budget when needed.
    :param config: The bot configuration object
    :param usage: The usage tracker object
    :param update: Telegram update object
    :param is_inline: Boolean flag for inline queries
    :return: True if the remaining budget is strictly positive
    """
    sender = update.inline_query.from_user if is_inline else update.message.from_user
    if sender.id not in usage:
        usage[sender.id] = UsageTracker(sender.id, sender.name)
    return get_remaining_budget(config, usage, update, is_inline=is_inline) > 0
def add_chat_request_to_usage_tracker(usage, config, user_id, used_tokens):
    """
    Add chat request to usage tracker.
    Records the tokens on the user's tracker and, for users missing from the allowed
    user list, also on the shared 'guests' tracker if it exists. This is best effort:
    any error is logged as a warning and not raised.
    :param usage: The usage tracker object
    :param config: The bot configuration object
    :param user_id: The user id
    :param used_tokens: The number of tokens used
    """
    try:
        if int(used_tokens) == 0:
            logging.warning('No tokens used. Not adding chat request to usage tracker.')
            return
        user_tracker = usage[user_id]
        user_tracker.add_chat_tokens(used_tokens, config['token_price'])
        # NOTE(review): only the allowed list is consulted, so an admin whose id is not
        # in it is also charged to the guest tracker here.
        is_guest = str(user_id) not in config['allowed_user_ids'].split(',')
        if is_guest and 'guests' in usage:
            usage["guests"].add_chat_tokens(used_tokens, config['token_price'])
    except Exception as err:
        logging.warning(f'Failed to add tokens to usage_logs: {str(err)}')
def get_reply_to_message_id(config, update: Update):
    """
    Returns the message id of the message to reply to.
    Replies quote the incoming message when quoting is enabled or the chat is a group.
    :param config: Bot configuration object
    :param update: Telegram update object
    :return: Message id of the message to reply to, or None if quoting is disabled
    """
    should_quote = config['enable_quoting'] or is_group_chat(update)
    return update.message.message_id if should_quote else None
def is_direct_result(response: any) -> bool:
    """
    Checks if the response contains a direct result that can be sent directly to the user.
    :param response: The response value, either a dict or a JSON-encoded string
    :return: The 'direct_result' value itself when present (truthy for a real result),
             otherwise False
    """
    if not isinstance(response, dict):
        try:
            response = json.loads(response)
        except (TypeError, ValueError):
            # Not a str/bytes at all (TypeError) or not valid JSON (JSONDecodeError is a
            # ValueError): there is no direct result.
            return False
        if not isinstance(response, dict):
            # Valid JSON, but not an object (e.g. a list or a number).
            return False
    return response.get('direct_result', False)
async def handle_direct_result(config, update: Update, response: any):
    """
    Handles a direct result from a plugin by replying with it directly.
    Supported kinds are 'photo', 'gif'/'file' (sent as a document) and 'dice'. The
    'url' and 'path' formats are supported for photos and documents. Files produced
    with the 'path' format are removed afterwards, even if sending fails.
    :param config: Bot configuration object (used to decide on reply quoting)
    :param update: Telegram update object
    :param response: The plugin response, a dict or JSON string whose 'direct_result'
                     entry has 'kind', 'format' and 'value' keys
    """
    if not isinstance(response, dict):
        response = json.loads(response)

    result = response['direct_result']
    kind = result['kind']
    result_format = result['format']
    value = result['value']

    common_args = {
        'message_thread_id': get_thread_id(update),
        'reply_to_message_id': get_reply_to_message_id(config, update),
    }

    try:
        if kind == 'photo':
            if result_format == 'url':
                await update.effective_message.reply_photo(**common_args, photo=value)
            elif result_format == 'path':
                # Close the file handle once the upload is done.
                with open(value, 'rb') as photo:
                    await update.effective_message.reply_photo(**common_args, photo=photo)
        elif kind in ('gif', 'file'):
            if result_format == 'url':
                await update.effective_message.reply_document(**common_args, document=value)
            elif result_format == 'path':
                with open(value, 'rb') as document:
                    await update.effective_message.reply_document(**common_args, document=document)
        elif kind == 'dice':
            await update.effective_message.reply_dice(**common_args, emoji=value)
    finally:
        # Plugin-generated files are intermediate; remove them whether or not sending succeeded.
        if result_format == 'path':
            cleanup_intermediate_files(response)
def cleanup_intermediate_files(response: any):
    """
    Deletes intermediate files created by plugins.
    If the response's direct result has the 'path' format and that path exists,
    the file is removed; otherwise nothing happens.
    :param response: The plugin response, a dict or a JSON string with a 'direct_result' entry
    """
    payload = response if type(response) is dict else json.loads(response)
    details = payload['direct_result']
    result_format = details['format']
    file_path = details['value']
    if result_format == 'path' and os.path.exists(file_path):
        os.remove(file_path)
def encode_image(fileobj, mime_type='image/jpeg'):
    """
    Encodes the contents of an in-memory file as a base64 data URL.
    :param fileobj: Object whose getvalue() returns the bytes to encode (e.g. io.BytesIO)
    :param mime_type: Media type written into the data URL header; defaults to
                      'image/jpeg', which matches the previous hard-coded value
    :return: A string of the form 'data:<mime_type>;base64,<payload>'
    """
    image = base64.b64encode(fileobj.getvalue()).decode('utf-8')
    return f'data:{mime_type};base64,{image}'
def decode_image(imgbase64):
    """
    Decodes a base64 image string back into raw bytes.
    Accepts a data URL with any media type (e.g. 'data:image/jpeg;base64,...')
    or a bare base64 payload.
    :param imgbase64: The data URL or base64 string
    :return: The decoded bytes
    """
    if imgbase64.startswith('data:'):
        # Drop the 'data:<media type>;base64,' header, whatever its length.
        _, _, imgbase64 = imgbase64.partition(',')
    return base64.b64decode(imgbase64)
def compute_image_diff(im1, im2, threshold=256):
    """
    Makes transparent the region in which two same-sized images differ.
    A pixel counts as different when the sum of absolute per-channel differences
    between im1 and im2 exceeds threshold. The bounding box of all such pixels is
    filled with fully transparent white (255, 255, 255, 0) in an RGBA copy of im1.
    :param im1: The base image; anything PIL.Image.open accepts (path or file object)
    :param im2: The image to compare against (e.g. a mask); must have im1's size
    :param threshold: Summed channel difference above which a pixel counts as changed
    :return: io.BytesIO positioned at offset 0, containing the result encoded as PNG
    :raises ValueError: if the image sizes differ or if no pixel differs
    """
    # The context managers release files opened by Image.open once the pixels are read.
    with Image.open(im1) as img1, Image.open(im2) as img2:
        if img1.size != img2.size:
            raise ValueError("The image and the mask must be of the same size.")
        width, height = img1.size
        pixels1 = img1.load()
        pixels2 = img2.load()
        transparent_box = img1.convert('RGBA')

        # Start with an empty (inverted) bounding box so any differing pixel shrinks it.
        xtop, xbottom = width, 0
        ytop, ybottom = height, 0
        for y in range(height):
            for x in range(width):
                # NOTE(review): assumes multi-band pixels (tuples); zip() compares only
                # the bands both images have.
                diff = sum(abs(c1 - c2) for c1, c2 in zip(pixels1[x, y], pixels2[x, y]))
                if diff > threshold:
                    xtop = min(xtop, x)
                    xbottom = max(xbottom, x)
                    ytop = min(ytop, y)
                    ybottom = max(ybottom, y)

    if xbottom < xtop or ybottom < ytop:
        raise ValueError('No difference detected in the images')

    # Filling with a colour (no mask) replaces every pixel in the box; the box end is exclusive.
    transparent_box.paste((255, 255, 255, 0), (xtop, ytop, xbottom + 1, ybottom + 1))

    res = io.BytesIO()
    transparent_box.save(res, format='PNG')
    res.seek(0)
    return res