1
0
mirror of https://github.com/eternnoir/pyTelegramBotAPI.git synced 2023-08-10 21:12:57 +03:00
pyTelegramBotAPI/telebot/util.py

486 lines
14 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
2018-01-15 16:08:50 +03:00
import random
2018-08-17 12:54:26 +03:00
import re
2018-01-15 16:08:50 +03:00
import string
import threading
2017-06-20 15:45:18 +03:00
import traceback
from typing import Any, Callable, List, Dict, Optional, Union
2018-08-17 12:54:26 +03:00
2021-08-18 22:16:30 +03:00
# noinspection PyPep8Naming
import queue as Queue
2018-03-10 09:41:34 +03:00
import logging
from telebot import types
try:
2021-08-18 22:16:30 +03:00
# noinspection PyPackageRequirements
from PIL import Image
from io import BytesIO
pil_imported = True
except:
pil_imported = False
# Upper bound used by smart_split() for the length of one text message part.
MAX_MESSAGE_LENGTH = 4096

logger = logging.getLogger('TeleBot')

# Backing storage for per_thread(): attributes set on it are visible only to the setting thread.
thread_local = threading.local()

# Message content types grouped as "media" (counterpart of content_type_service below).
content_type_media = [
    'text', 'audio', 'animation', 'document', 'photo', 'sticker', 'video',
    'video_note', 'voice', 'contact', 'dice', 'poll', 'venue', 'location',
]

# Message content types grouped as "service" events.
content_type_service = [
    'new_chat_members', 'left_chat_member', 'new_chat_title', 'new_chat_photo',
    'delete_chat_photo', 'group_chat_created', 'supergroup_chat_created',
    'channel_chat_created', 'migrate_to_chat_id', 'migrate_from_chat_id',
    'pinned_message', 'proximity_alert_triggered', 'voice_chat_scheduled',
    'voice_chat_started', 'voice_chat_ended', 'voice_chat_participants_invited',
    'message_auto_delete_timer_changed',
]

# Field names that may appear in an Update object.
update_types = [
    "update_id", "message", "edited_message", "channel_post", "edited_channel_post",
    "inline_query", "chosen_inline_result", "callback_query", "shipping_query",
    "pre_checkout_query", "poll", "poll_answer", "my_chat_member", "chat_member",
]
class WorkerThread(threading.Thread):
    """
    Daemon thread that runs ``(task, args, kwargs)`` tuples taken from a queue.

    The thread starts itself from ``__init__``. For every task it clears its events,
    sets ``received_task_event``, calls ``task(*args, **kwargs)`` and then sets
    ``done_event``. If the task raises, the exception is stored in ``exception_info``,
    ``exception_event`` is set, ``exception_callback(self, exception)`` is called when
    provided, and the thread blocks on ``continue_event`` (released by
    ``clear_exceptions()`` or ``stop()``) before taking the next task.
    """

    # Counter used to build default thread names ("WorkerThread1", "WorkerThread2", ...).
    count = 0

    def __init__(self, exception_callback=None, queue=None, name=None):
        """
        :param exception_callback: optional callable invoked as
            ``exception_callback(worker, exception)`` when a task raises.
        :param queue: queue to take tasks from; a new ``Queue.Queue`` is created when omitted.
        :param name: thread name; defaults to ``WorkerThread<n>``.
        """
        if not name:
            name = "WorkerThread{0}".format(self.__class__.count + 1)
            self.__class__.count += 1
        # Compare against None: a caller-supplied queue object that happens to be falsy
        # (e.g. one defining __len__ while empty) must not be silently replaced.
        if queue is None:
            queue = Queue.Queue()

        threading.Thread.__init__(self, name=name)
        self.queue = queue
        self.daemon = True

        self.received_task_event = threading.Event()
        self.done_event = threading.Event()
        self.exception_event = threading.Event()
        self.continue_event = threading.Event()

        self.exception_callback = exception_callback
        self.exception_info = None
        self._running = True
        self.start()

    def run(self):
        """Thread body: process queued tasks until ``stop()`` is called."""
        while self._running:
            try:
                # Short timeout so the loop re-checks ``_running`` regularly.
                task, args, kwargs = self.queue.get(block=True, timeout=.5)
                self.continue_event.clear()
                self.received_task_event.clear()
                self.done_event.clear()
                self.exception_event.clear()
                logger.debug("Received task")
                self.received_task_event.set()

                task(*args, **kwargs)
                logger.debug("Task complete")
                self.done_event.set()
            except Queue.Empty:
                pass
            except Exception as e:
                logger.debug("%s occurred, args=%s\n%s", type(e).__name__, e.args, traceback.format_exc())
                self.exception_info = e
                self.exception_event.set()
                if self.exception_callback:
                    self.exception_callback(self, self.exception_info)
                # Park until clear_exceptions() (or stop()) releases this worker.
                self.continue_event.wait()

    def put(self, task, *args, **kwargs):
        """Queue ``task(*args, **kwargs)`` for execution by this worker."""
        self.queue.put((task, args, kwargs))

    def raise_exceptions(self):
        """Re-raise the exception from the last failed task, if one is pending."""
        if self.exception_event.is_set():
            raise self.exception_info

    def clear_exceptions(self):
        """Drop the pending-exception flag and let a parked worker continue."""
        self.exception_event.clear()
        self.continue_event.set()

    def stop(self):
        """Ask the thread to exit after its current iteration."""
        self._running = False
        # Release a worker parked on continue_event after a failed task; otherwise it
        # would never re-check ``_running`` and a subsequent join() would block forever.
        self.continue_event.set()
class ThreadPool:
    """
    A fixed set of WorkerThread objects that all consume one shared task queue.

    Task failures are reported through ``on_exception``: the pool records the most
    recent exception and immediately lets the failing worker continue.
    """

    def __init__(self, num_threads=2):
        """:param num_threads: number of WorkerThread instances to start."""
        self.tasks = Queue.Queue()
        workers = []
        for _ in range(num_threads):
            workers.append(WorkerThread(self.on_exception, self.tasks))
        self.workers = workers
        self.num_threads = num_threads
        self.exception_event = threading.Event()
        self.exception_info = None

    def put(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` for execution by any worker."""
        job = (func, args, kwargs)
        self.tasks.put(job)

    def on_exception(self, worker_thread, exc_info):
        """Worker callback: remember *exc_info* and unblock *worker_thread*."""
        self.exception_info = exc_info
        self.exception_event.set()
        worker_thread.continue_event.set()

    def raise_exceptions(self):
        """Re-raise the last recorded task exception while the flag is set."""
        if not self.exception_event.is_set():
            return
        raise self.exception_info

    def clear_exceptions(self):
        """Clear the flag that makes raise_exceptions() raise."""
        self.exception_event.clear()

    def close(self):
        """Ask every worker to stop, then wait for all of them to finish."""
        pool_workers = self.workers
        for member in pool_workers:
            member.stop()
        for member in pool_workers:
            member.join()
class AsyncTask:
    """
    Run ``target(*args, **kwargs)`` in a new thread as soon as the object is created.

    Call ``wait()`` to block until the call finishes; it returns the call's result
    or re-raises the ``Exception`` the call raised.
    """

    def __init__(self, target, *args, **kwargs):
        self.target = target
        self.args = args
        self.kwargs = kwargs
        self.done = False
        self.result = None
        # The raised exception is kept apart from `result`, so a target that legitimately
        # *returns* an exception instance is not mistaken for a failure.
        self._exception = None
        self.thread = threading.Thread(target=self._run)
        self.thread.start()

    def _run(self):
        """Thread body: execute the target and record its outcome."""
        try:
            self.result = self.target(*self.args, **self.kwargs)
        except Exception as e:
            # `result` still receives the exception, for callers that inspect it directly.
            self.result = e
            self._exception = e
        self.done = True

    def wait(self):
        """
        Block until the target has finished.

        :return: the target's return value.
        :raises Exception: whatever the target raised.
        """
        if not self.done:
            self.thread.join()
        if self._exception is not None:
            raise self._exception
        return self.result
def async_dec():
    """
    Decorator factory: each call to the decorated function runs in its own thread.

    The decorated function returns an AsyncTask instead of its result; call ``.wait()``
    on it to obtain the result (or the re-raised exception).
    """
    import functools

    def decorator(fn):
        # Preserve the wrapped function's __name__, __doc__, etc.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return AsyncTask(fn, *args, **kwargs)
        return wrapper
    return decorator
def is_string(var):
    """Return True when *var* is a ``str`` instance (subclasses included)."""
    return isinstance(var, (str,))
def is_dict(var):
    """Return True when *var* is a ``dict`` instance (subclasses included)."""
    return isinstance(var, (dict,))
def is_bytes(var):
    """Return True when *var* is a ``bytes`` instance (subclasses included)."""
    return isinstance(var, (bytes,))
def is_pil_image(var):
    """Return True when PIL was importable and *var* is a ``PIL.Image.Image``."""
    if not pil_imported:
        return False
    return isinstance(var, Image.Image)
def pil_image_to_file(image, extension='JPEG', quality='web_low'):
    """
    Encode a PIL image into an in-memory buffer.

    The image is converted to RGB and saved with the given format and quality.

    :param image: a PIL image.
    :param extension: format name passed to ``Image.save``.
    :param quality: ``quality`` option passed to ``Image.save``.
    :return: a BytesIO positioned at offset 0.
    :raises RuntimeError: if PIL could not be imported.
    """
    if not pil_imported:
        raise RuntimeError('PIL module is not imported')
    buffer = BytesIO()
    rgb_image = image.convert('RGB')
    rgb_image.save(buffer, extension, quality=quality)
    buffer.seek(0)
    return buffer
def is_command(text: str) -> bool:
    """
    Tell whether `text` is a Telegram chat command, i.e. starts with '/'.

    :param text: Text to check (None is accepted and yields False).
    :return: True if `text` is a command, else False.
    """
    return text is not None and text.startswith('/')
def extract_command(text: str) -> Union[str, None]:
    """
    Return the command name of `text` without the leading '/' and any '@BotName' suffix.

    Examples:
    extract_command('/help'): 'help'
    extract_command('/help@BotName'): 'help'
    extract_command('/search black eyed peas'): 'search'
    extract_command('Good day to you'): None

    :param text: String to extract the command from (None is accepted).
    :return: the command if `text` is a command (according to is_command), else None.
    """
    if not is_command(text):
        return None
    first_token = text.split()[0]
    command_with_slash = first_token.split('@')[0]
    return command_with_slash[1:]
def extract_arguments(text: str) -> Optional[str]:
    """
    Returns the argument after the command.

    Examples:
    extract_arguments("/get name"): 'name'
    extract_arguments("/get"): ''
    extract_arguments("/get@botName name"): 'name'
    extract_arguments("not a command"): None

    :param text: String to extract the arguments from a command (None is accepted).
    :return: the arguments if `text` is a command (according to is_command), else None.
    """
    # Check first: the previous version ran the regex before this test and raised
    # TypeError for text=None.
    if not is_command(text):
        return None
    regexp = re.compile(r"/\w*(@\w*)*\s*([\s\S]*)", re.IGNORECASE)
    # `text` starts with '/', so the pattern always matches.
    result = regexp.match(text)
    return result.group(2)
def split_string(text: str, chars_per_string: int) -> List[str]:
    """
    Cut `text` into consecutive pieces of at most `chars_per_string` characters.

    Useful for breaking one oversized message into several. Only the final piece may be shorter.

    :param text: The text to split
    :param chars_per_string: Maximum number of characters per piece.
    :return: The pieces, in order, as a list of strings (empty for empty text).
    """
    pieces = []
    for start in range(0, len(text), chars_per_string):
        pieces.append(text[start:start + chars_per_string])
    return pieces
def smart_split(text: str, chars_per_string: int=MAX_MESSAGE_LENGTH) -> List[str]:
    """
    Splits one string into multiple strings, with a maximum amount of `chars_per_string` characters per string.
    This is very useful for splitting one giant message into multiples.
    If `chars_per_string` > 4096: `chars_per_string` = 4096.
    Splits by '\n', '. ' or ' ' in exactly this priority; a chunk with none of these is cut hard.

    :param text: The text to split
    :param chars_per_string: The number of maximum characters per part the text is split to.
    :return: The splitted text as a list of strings.
    :raises ValueError: if `chars_per_string` is smaller than 1 (previously an endless loop).
    """
    if chars_per_string < 1:
        raise ValueError("chars_per_string must be a positive integer")

    def _text_before_last(part: str, substr: str) -> str:
        # Everything up to and including the last occurrence of `substr`.
        return part[:part.rfind(substr) + len(substr)]

    if chars_per_string > MAX_MESSAGE_LENGTH:
        chars_per_string = MAX_MESSAGE_LENGTH
    parts = []
    # `>` rather than the previous `>=`: text that already fits must not be split again,
    # and an exact multiple of the limit must not leave a trailing empty part.
    while len(text) > chars_per_string:
        part = text[:chars_per_string]
        if "\n" in part:
            part = _text_before_last(part, "\n")
        elif ". " in part:
            part = _text_before_last(part, ". ")
        elif " " in part:
            part = _text_before_last(part, " ")
        parts.append(part)
        text = text[len(part):]
    parts.append(text)
    return parts
def escape(text: Optional[str]) -> Optional[str]:
    """
    Replaces the following chars in `text` ('&' with '&amp;', '<' with '&lt;' and '>' with '&gt;').

    :param text: the text to escape (None is returned unchanged)
    :return: the escaped text
    """
    if text is None:
        return None
    # '&' must be handled first so the '&' of the other entities is not escaped again.
    # (The '>' entity previously lacked its trailing ';'.)
    chars = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
    for old, new in chars.items():
        text = text.replace(old, new)
    return text
def user_link(user: types.User, include_id: bool=False) -> str:
    """
    Build an HTML mention link for `user`, handy for reports.

    Attention: Don't forget to set parse_mode to 'HTML'!
    Example:
    bot.send_message(your_user_id, user_link(message.from_user) + ' started the bot!', parse_mode='HTML')

    :param user: the user object (not the user_id); its first_name is HTML-escaped
    :param include_id: also append the numeric user id in a <pre> tag
    :return: HTML user link
    """
    display_name = escape(user.first_name)
    link = f"<a href='tg://user?id={user.id}'>{display_name}</a>"
    if include_id:
        link += f" (<pre>{user.id}</pre>)"
    return link
def quick_markup(values: Dict[str, Dict[str, Any]], row_width: int=2) -> types.InlineKeyboardMarkup:
    """
    Build an InlineKeyboardMarkup from a mapping of button text to button keyword arguments.

    This saves writing 'btn1 = InlineKeyboardButton(...)', 'btn2 = InlineKeyboardButton(...)' by hand.
    Example:
    quick_markup({
        'Twitter': {'url': 'https://twitter.com'},
        'Facebook': {'url': 'https://facebook.com'},
        'Back': {'callback_data': 'whatever'}
    }, row_width=2)
    returns a markup with the Twitter and Facebook buttons in one row and the Back button below.

    Each kwargs dict is forwarded to InlineKeyboardButton, e.g. with keys such as
    'url', 'callback_data', 'switch_inline_query', 'switch_inline_query_current_chat',
    'callback_game', 'pay' or 'login_url'.

    :param values: dict in the form {button_text: button_kwargs}
    :param row_width: number of buttons per row
    :return: InlineKeyboardMarkup
    """
    markup = types.InlineKeyboardMarkup(row_width=row_width)
    button_list = []
    for label, options in values.items():
        button_list.append(types.InlineKeyboardButton(text=label, **options))
    markup.add(*button_list)
    return markup
# CREDITS TO http://stackoverflow.com/questions/12317940#answer-12320352
def or_set(self):
    """Replacement ``set``: run the saved original ``_set``, then notify ``changed``."""
    original_set = self._set
    original_set()
    notify = self.changed
    notify()
def or_clear(self):
    """Replacement ``clear``: run the saved original ``_clear``, then notify ``changed``."""
    original_clear = self._clear
    original_clear()
    notify = self.changed
    notify()
def orify(e, changed_callback):
    """
    Patch event `e` so that its set()/clear() also call `changed_callback`.

    The genuine methods are saved as ``_set``/``_clear`` only the first time, so
    orifying the same event again does not wrap the wrappers.
    """
    for public_name, saved_name in (("set", "_set"), ("clear", "_clear")):
        if not hasattr(e, saved_name):
            setattr(e, saved_name, getattr(e, public_name))
    e.changed = changed_callback
    e.set = lambda: or_set(e)
    e.clear = lambda: or_clear(e)
def OrEvent(*events):
    """
    Create a threading.Event that is set whenever at least one of `events` is set.

    Every event in `events` is patched with orify() so that calling its set()/clear()
    re-evaluates the combined state. The returned event's wait() is replaced: without a
    timeout it waits in 3-second slices until the event is set; with a timeout it
    behaves like threading.Event.wait(timeout).

    :param events: threading.Event instances to combine.
    :return: the combined event.
    """
    or_event = threading.Event()

    def changed():
        # Mirror "any source event is set" onto the combined event.
        if any(ev.is_set() for ev in events):
            or_event.set()
        else:
            or_event.clear()

    def busy_wait(timeout=None):
        # Accept the same optional timeout as threading.Event.wait so the replacement is
        # a drop-in (the previous version raised TypeError when given a timeout).
        if timeout is not None:
            return or_event._wait(timeout)
        # Without a timeout, wait in 3-second slices; presumably so the waiting thread
        # wakes periodically (e.g. to stay responsive to interrupts) — confirm intent.
        while not or_event.is_set():
            or_event._wait(3)
        return True

    for e in events:
        orify(e, changed)
    or_event._wait = or_event.wait
    or_event.wait = busy_wait
    changed()
    return or_event
def per_thread(key, construct_value, reset=False):
    """
    Return the thread-local value stored under `key`.

    The value is (re)built with ``construct_value()`` when this thread has none yet
    or when `reset` is True.
    """
    if not reset and hasattr(thread_local, key):
        return getattr(thread_local, key)
    setattr(thread_local, key, construct_value())
    return getattr(thread_local, key)
def chunks(lst, n):
    """Yield consecutive slices of `lst` of length `n`; the last one may be shorter."""
    # https://stackoverflow.com/a/312464/9935473
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
def generate_random_token():
    """
    Return a 16-character token made of distinct ASCII letters.

    Uses random.SystemRandom (OS entropy) rather than the module-level Mersenne
    Twister, whose output is predictable from previous values. NOTE(review): this
    assumes the token may be used as a secret — confirm against callers.
    """
    return ''.join(random.SystemRandom().sample(string.ascii_letters, 16))
def deprecated(warn: bool=False, alternative: Optional[Callable]=None):
    """
    Use this decorator to mark functions as deprecated.
    When the function is used, an info (or warning if `warn` is True) is logged.

    :param warn: If True a warning is logged else an info
    :param alternative: The new function to use instead
    :return: a decorator; the wrapped function keeps its name and docstring.
    """
    import functools

    def decorator(function):
        message = (f"`{function.__name__}` is deprecated."
                   + (f" Use `{alternative.__name__}` instead" if alternative else ""))

        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            if warn:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning(message)
            else:
                logger.info(message)
            return function(*args, **kwargs)
        return wrapper
    return decorator
# Cloud helpers
def webhook_google_functions(bot, request):
    """
    A webhook endpoint for Google Cloud Functions FaaS.

    :param bot: bot whose ``process_new_updates`` receives the parsed update.
    :param request: incoming HTTP request exposing ``is_json`` and ``get_json()``
        (presumably the Flask request Cloud Functions passes in — confirm).
    :return: ``''`` on success, ``('Bot FAIL', 400)`` if processing raised,
        ``'Bot ON'`` for non-JSON requests.
    """
    if not request.is_json:
        return 'Bot ON'
    try:
        request_json = request.get_json()
        update = types.Update.de_json(request_json)
        bot.process_new_updates([update])
        return ''
    except Exception as e:
        # Report through the module logger (with traceback) instead of printing to stdout.
        logger.error("webhook_google_functions failed: %s", e, exc_info=True)
        return 'Bot FAIL', 400
class SimpleCustomFilter:
    """
    Simple Custom Filter base class.
    Create child class with check() method.
    Accepts only bool.
    """

    def check(self, message):
        """
        Perform a check.

        Override in subclasses and return a bool. The base implementation returns None.
        (The previous signature lacked ``self``, so calling it on an instance raised TypeError.)

        :param message: the message to check.
        """
        pass
class AdvancedCustomFilter:
    """
    Advanced Custom Filter base class.
    Create child class with check() method.
    Accepts two parameters:
    message: Message class
    text: Filter value given in handler
    """

    def check(self, message, text):
        """
        Perform a check.

        Override in subclasses and return a bool. The base implementation returns None.
        (The previous signature lacked ``self``, so calling it on an instance raised TypeError.)

        :param message: the message to check.
        :param text: filter value given in the handler.
        """
        pass