From 6770011dd77ab5e64255bd842034e1de180e02bb Mon Sep 17 00:00:00 2001 From: _run Date: Sat, 27 Nov 2021 19:04:03 +0500 Subject: [PATCH] Middleware support --- telebot/__init__.py | 258 ++++++++++++---------------- telebot/asyncio_handler_backends.py | 156 ++--------------- telebot/asyncio_helper.py | 31 +--- 3 files changed, 138 insertions(+), 307 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index 68c2e48..89689eb 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -13,6 +13,9 @@ from typing import Any, Callable, List, Optional, Union import telebot.util import telebot.types + +from inspect import signature + logger = logging.getLogger('TeleBot') formatter = logging.Formatter( @@ -69,6 +72,30 @@ class ExceptionHandler: return False +class SkipHandler: + """ + Class for skipping handlers. + Just return instance of this class + in middleware to skip handler. + Update will go to post_process, + but will skip execution of handler. + """ + + def __init__(self) -> None: + pass + +class CancelUpdate: + """ + Class for canceling updates. + Just return instance of this class + in middleware to skip update. + Update will skip handler and execution + of post_process in middlewares. + """ + + def __init__(self) -> None: + pass + class TeleBot: """ This is TeleBot Class Methods: @@ -3351,33 +3378,16 @@ class AsyncTeleBot: self.current_states = asyncio_handler_backends.StateMemory() - if asyncio_helper.ENABLE_MIDDLEWARE: - self.typed_middleware_handlers = { - 'message': [], - 'edited_message': [], - 'channel_post': [], - 'edited_channel_post': [], - 'inline_query': [], - 'chosen_inline_result': [], - 'callback_query': [], - 'shipping_query': [], - 'pre_checkout_query': [], - 'poll': [], - 'poll_answer': [], - 'my_chat_member': [], - 'chat_member': [], - 'chat_join_request': [] - } - self.default_middleware_handlers = [] + self.middlewares = [] async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None, - timeout: Optional[int]=None, allowed_updates: Optional[List]=None) -> types.Update: - json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates) + timeout: Optional[int]=None, allowed_updates: Optional[List]=None, request_timeout: Optional[int]=None) -> types.Update: + json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates, request_timeout) return [types.Update.de_json(ju) for ju in json_updates] async def polling(self, non_stop: bool=False, skip_pending=False, interval: int=0, timeout: int=20, - long_polling_timeout: int=20, allowed_updates: Optional[List[str]]=None, + request_timeout: int=20, allowed_updates: Optional[List[str]]=None, none_stop: Optional[bool]=None): """ This allows the bot to retrieve Updates automatically and notify listeners and message handlers accordingly. @@ -3389,7 +3399,7 @@ class AsyncTeleBot: :param non_stop: Do not stop polling when an ApiException occurs. :param timeout: Request connection timeout :param skip_pending: skip old updates - :param long_polling_timeout: Timeout in seconds for long polling (see API docs) + :param request_timeout: Timeout in seconds for a request. :param allowed_updates: A list of the update types you want your bot to receive. For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types. See util.update_types for a complete list of available update types. @@ -3406,9 +3416,9 @@ class AsyncTeleBot: if skip_pending: await self.skip_updates() - await self._process_polling(non_stop, interval, timeout, long_polling_timeout, allowed_updates) + await self._process_polling(non_stop, interval, timeout, request_timeout, allowed_updates) - async def infinity_polling(self, timeout: int=20, skip_pending: bool=False, long_polling_timeout: int=20, logger_level=logging.ERROR, + async def infinity_polling(self, timeout: int=20, skip_pending: bool=False, request_timeout: int=20, logger_level=logging.ERROR, allowed_updates: Optional[List[str]]=None, *args, **kwargs): """ Wrap polling with infinite loop and exception handling to avoid bot stops polling. @@ -3432,7 +3442,7 @@ class AsyncTeleBot: self._polling = True while self._polling: try: - await self._process_polling(non_stop=True, timeout=timeout, long_polling_timeout=long_polling_timeout, + await self._process_polling(non_stop=True, timeout=timeout, request_timeout=request_timeout, allowed_updates=allowed_updates, *args, **kwargs) except Exception as e: if logger_level and logger_level >= logging.ERROR: @@ -3447,13 +3457,13 @@ class AsyncTeleBot: logger.error("Break infinity polling") async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: int=20, - long_polling_timeout: int=20, allowed_updates: Optional[List[str]]=None): + request_timeout: int=20, allowed_updates: Optional[List[str]]=None): """ Function to process polling. :param non_stop: Do not stop polling when an ApiException occurs. :param interval: Delay between two update retrivals :param timeout: Request connection timeout - :param long_polling_timeout: Timeout in seconds for long polling (see API docs) + :param request_timeout: Timeout in seconds for long polling (see API docs) :param allowed_updates: A list of the update types you want your bot to receive. For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types. See util.update_types for a complete list of available update types. @@ -3471,49 +3481,74 @@ class AsyncTeleBot: while self._polling: try: - updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout) - - if updates: - logger.debug(f"Received {len(updates)} updates.") - - await self.process_new_updates(updates) - if interval: await asyncio.sleep(interval) - except KeyboardInterrupt: - logger.info("KeyboardInterrupt received.") - break + updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout) except asyncio.CancelledError: - break + return + except asyncio_helper.ApiTelegramException as e: - logger.info(str(e)) + logger.error(str(e)) continue except Exception as e: logger.error('Cause exception while getting updates.') - logger.error(str(e)) - await asyncio.sleep(3) - continue + if non_stop: + logger.error(str(e)) + await asyncio.sleep(3) + continue + else: + raise e + if updates: + self.offset = updates[-1].update_id + 1 + self._loop_create_task(self.process_new_updates(updates)) # Seperate task for processing updates + if interval: await asyncio.sleep(interval) finally: self._polling = False logger.warning('Polling is stopped.') - async def _loop_create_task(self, coro): + def _loop_create_task(self, coro): return asyncio.create_task(coro) - async def _process_updates(self, handlers, messages): + async def _process_updates(self, handlers, messages, update_type): + """ + Process updates. + :param handlers: + :param messages: + :return: + """ for message in messages: - for message_handler in handlers: - process_update = await self._test_message_handler(message_handler, message) - if not process_update: - continue - elif process_update: - try: - await self._loop_create_task(message_handler['function'](message)) - break - except Exception as e: - logger.error(str(e)) + middleware = await self.process_middlewares(message, update_type) + self._loop_create_task(self._run_middlewares_and_handlers(handlers, message, middleware)) + + async def _run_middlewares_and_handlers(self, handlers, message, middleware): + handler_error = None + data = {} + for message_handler in handlers: + process_update = await self._test_message_handler(message_handler, message) + if not process_update: + continue + elif process_update: + if middleware: + middleware_result = await middleware.pre_process(message, data) + if isinstance(middleware_result, SkipHandler): + await middleware.post_process(message, data, handler_error) + break + if isinstance(middleware_result, CancelUpdate): + return + try: + if "data" in signature(message_handler['function']).parameters: + await message_handler['function'](message, data) + else: + await message_handler['function'](message) + break + except Exception as e: + handler_error = e + logger.info(str(e)) + + if middleware: + await middleware.post_process(message, data, handler_error) # update handling async def process_new_updates(self, updates): upd_count = len(updates) @@ -3535,21 +3570,8 @@ class AsyncTeleBot: new_chat_members = None chat_join_request = None for update in updates: - if asyncio_helper.ENABLE_MIDDLEWARE: - try: - self.process_middlewares(update) - except Exception as e: - logger.error(str(e)) - if not self.suppress_middleware_excepions: - raise - else: - if update.update_id > self.offset: self.offset = update.update_id - continue logger.debug('Processing updates: {0}'.format(update)) - if update.update_id: - self.offset = update.update_id + 1 if update.message: - logger.info('Processing message') if new_messages is None: new_messages = [] new_messages.append(update.message) if update.edited_message: @@ -3620,69 +3642,55 @@ class AsyncTeleBot: await self.process_new_chat_member(new_chat_members) if chat_join_request: await self.process_chat_join_request(chat_join_request) - async def process_new_messages(self, new_messages): await self.__notify_update(new_messages) - await self._process_updates(self.message_handlers, new_messages) + await self._process_updates(self.message_handlers, new_messages, 'message') async def process_new_edited_messages(self, edited_message): - await self._process_updates(self.edited_message_handlers, edited_message) + await self._process_updates(self.edited_message_handlers, edited_message, 'edited_message') async def process_new_channel_posts(self, channel_post): - await self._process_updates(self.channel_post_handlers, channel_post) + await self._process_updates(self.channel_post_handlers, channel_post , 'channel_post') async def process_new_edited_channel_posts(self, edited_channel_post): - await self._process_updates(self.edited_channel_post_handlers, edited_channel_post) + await self._process_updates(self.edited_channel_post_handlers, edited_channel_post, 'edited_channel_post') async def process_new_inline_query(self, new_inline_querys): - await self._process_updates(self.inline_handlers, new_inline_querys) + await self._process_updates(self.inline_handlers, new_inline_querys, 'inline_query') async def process_new_chosen_inline_query(self, new_chosen_inline_querys): - await self._process_updates(self.chosen_inline_handlers, new_chosen_inline_querys) + await self._process_updates(self.chosen_inline_handlers, new_chosen_inline_querys, 'chosen_inline_query') async def process_new_callback_query(self, new_callback_querys): - await self._process_updates(self.callback_query_handlers, new_callback_querys) + await self._process_updates(self.callback_query_handlers, new_callback_querys, 'callback_query') async def process_new_shipping_query(self, new_shipping_querys): - await self._process_updates(self.shipping_query_handlers, new_shipping_querys) + await self._process_updates(self.shipping_query_handlers, new_shipping_querys, 'shipping_query') async def process_new_pre_checkout_query(self, pre_checkout_querys): - await self._process_updates(self.pre_checkout_query_handlers, pre_checkout_querys) + await self._process_updates(self.pre_checkout_query_handlers, pre_checkout_querys, 'pre_checkout_query') async def process_new_poll(self, polls): - await self._process_updates(self.poll_handlers, polls) + await self._process_updates(self.poll_handlers, polls, 'poll') async def process_new_poll_answer(self, poll_answers): - await self._process_updates(self.poll_answer_handlers, poll_answers) + await self._process_updates(self.poll_answer_handlers, poll_answers, 'poll_answer') async def process_new_my_chat_member(self, my_chat_members): - await self._process_updates(self.my_chat_member_handlers, my_chat_members) + await self._process_updates(self.my_chat_member_handlers, my_chat_members, 'my_chat_member') async def process_new_chat_member(self, chat_members): - await self._process_updates(self.chat_member_handlers, chat_members) + await self._process_updates(self.chat_member_handlers, chat_members, 'chat_member') async def process_chat_join_request(self, chat_join_request): - await self._process_updates(self.chat_join_request_handlers, chat_join_request) - - async def process_middlewares(self, update): - for update_type, middlewares in self.typed_middleware_handlers.items(): - if getattr(update, update_type) is not None: - for typed_middleware_handler in middlewares: - try: - typed_middleware_handler(self, getattr(update, update_type)) - except Exception as e: - e.args = e.args + (f'Typed middleware handler "{typed_middleware_handler.__qualname__}"',) - raise - - if len(self.default_middleware_handlers) > 0: - for default_middleware_handler in self.default_middleware_handlers: - try: - default_middleware_handler(self, update) - except Exception as e: - e.args = e.args + (f'Default middleware handler "{default_middleware_handler.__qualname__}"',) - raise + await self._process_updates(self.chat_join_request_handlers, chat_join_request, 'chat_join_request') + async def process_middlewares(self, update, update_type): + for middleware in self.middlewares: + if update_type in middleware.update_types: + return middleware + return None async def __notify_update(self, new_messages): if len(self.update_listener) == 0: @@ -3759,56 +3767,16 @@ class AsyncTeleBot: elif isinstance(filter_check, asyncio_filters.AdvancedCustomFilter): return await filter_check.check(message, filter_value) else: - logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter!") + logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.") return False - def middleware_handler(self, update_types=None): + def setup_middleware(self, middleware): """ - Middleware handler decorator. - - This decorator can be used to decorate functions that must be handled as middlewares before entering any other - message handlers - But, be careful and check type of the update inside the handler if more than one update_type is given - - Example: - - bot = TeleBot('TOKEN') - - # Print post message text before entering to any post_channel handlers - @bot.middleware_handler(update_types=['channel_post', 'edited_channel_post']) - def print_channel_post_text(bot_instance, channel_post): - print(channel_post.text) - - # Print update id before entering to any handlers - @bot.middleware_handler() - def print_channel_post_text(bot_instance, update): - print(update.update_id) - - :param update_types: Optional list of update types that can be passed into the middleware handler. - - """ - - def decorator(handler): - self.add_middleware_handler(handler, update_types) - return handler - - return decorator - - def add_middleware_handler(self, handler, update_types=None): - """ - Add middleware handler - :param handler: - :param update_types: + Setup middleware + :param middleware: :return: """ - if not asyncio_helper.ENABLE_MIDDLEWARE: - raise RuntimeError("Middleware is not enabled. Use asyncio_helper.ENABLE_MIDDLEWARE.") - - if update_types: - for update_type in update_types: - self.typed_middleware_handlers[update_type].append(handler) - else: - self.default_middleware_handlers.append(handler) + self.middlewares.append(middleware) def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs): """ diff --git a/telebot/asyncio_handler_backends.py b/telebot/asyncio_handler_backends.py index 001f869..b46c988 100644 --- a/telebot/asyncio_handler_backends.py +++ b/telebot/asyncio_handler_backends.py @@ -1,146 +1,6 @@ import os import pickle -import threading -from telebot import apihelper - - -class HandlerBackend(object): - """ - Class for saving (next step|reply) handlers - """ - def __init__(self, handlers=None): - if handlers is None: - handlers = {} - self.handlers = handlers - - async def register_handler(self, handler_group_id, handler): - raise NotImplementedError() - - async def clear_handlers(self, handler_group_id): - raise NotImplementedError() - - async def get_handlers(self, handler_group_id): - raise NotImplementedError() - - -class MemoryHandlerBackend(HandlerBackend): - async def register_handler(self, handler_group_id, handler): - if handler_group_id in self.handlers: - self.handlers[handler_group_id].append(handler) - else: - self.handlers[handler_group_id] = [handler] - - async def clear_handlers(self, handler_group_id): - self.handlers.pop(handler_group_id, None) - - async def get_handlers(self, handler_group_id): - return self.handlers.pop(handler_group_id, None) - - async def load_handlers(self, filename, del_file_after_loading): - raise NotImplementedError() - - -class FileHandlerBackend(HandlerBackend): - def __init__(self, handlers=None, filename='./.handler-saves/handlers.save', delay=120): - super(FileHandlerBackend, self).__init__(handlers) - self.filename = filename - self.delay = delay - self.timer = threading.Timer(delay, self.save_handlers) - - async def register_handler(self, handler_group_id, handler): - if handler_group_id in self.handlers: - self.handlers[handler_group_id].append(handler) - else: - self.handlers[handler_group_id] = [handler] - await self.start_save_timer() - - async def clear_handlers(self, handler_group_id): - self.handlers.pop(handler_group_id, None) - await self.start_save_timer() - - async def get_handlers(self, handler_group_id): - handlers = self.handlers.pop(handler_group_id, None) - await self.start_save_timer() - return handlers - - async def start_save_timer(self): - if not self.timer.is_alive(): - if self.delay <= 0: - self.save_handlers() - else: - self.timer = threading.Timer(self.delay, self.save_handlers) - self.timer.start() - - async def save_handlers(self): - await self.dump_handlers(self.handlers, self.filename) - - async def load_handlers(self, filename=None, del_file_after_loading=True): - if not filename: - filename = self.filename - tmp = await self.return_load_handlers(filename, del_file_after_loading=del_file_after_loading) - if tmp is not None: - self.handlers.update(tmp) - - @staticmethod - async def dump_handlers(handlers, filename, file_mode="wb"): - dirs = filename.rsplit('/', maxsplit=1)[0] - os.makedirs(dirs, exist_ok=True) - - with open(filename + ".tmp", file_mode) as file: - if (apihelper.CUSTOM_SERIALIZER is None): - pickle.dump(handlers, file) - else: - apihelper.CUSTOM_SERIALIZER.dump(handlers, file) - - if os.path.isfile(filename): - os.remove(filename) - - os.rename(filename + ".tmp", filename) - - @staticmethod - async def return_load_handlers(filename, del_file_after_loading=True): - if os.path.isfile(filename) and os.path.getsize(filename) > 0: - with open(filename, "rb") as file: - if (apihelper.CUSTOM_SERIALIZER is None): - handlers = pickle.load(file) - else: - handlers = apihelper.CUSTOM_SERIALIZER.load(file) - - if del_file_after_loading: - os.remove(filename) - - return handlers - - -class RedisHandlerBackend(HandlerBackend): - def __init__(self, handlers=None, host='localhost', port=6379, db=0, prefix='telebot', password=None): - super(RedisHandlerBackend, self).__init__(handlers) - from redis import Redis - self.prefix = prefix - self.redis = Redis(host, port, db, password) - - async def _key(self, handle_group_id): - return ':'.join((self.prefix, str(handle_group_id))) - - async def register_handler(self, handler_group_id, handler): - handlers = [] - value = self.redis.get(self._key(handler_group_id)) - if value: - handlers = pickle.loads(value) - handlers.append(handler) - self.redis.set(self._key(handler_group_id), pickle.dumps(handlers)) - - async def clear_handlers(self, handler_group_id): - self.redis.delete(self._key(handler_group_id)) - - async def get_handlers(self, handler_group_id): - handlers = None - value = self.redis.get(self._key(handler_group_id)) - if value: - handlers = pickle.loads(value) - self.clear_handlers(handler_group_id) - return handlers class StateMemory: @@ -341,3 +201,19 @@ class StateFileContext: await self.obj._save_data(old_data) return + + +class BaseMiddleware: + """ + Base class for middleware. + + Your middlewares should be inherited from this class. + """ + def __init__(self): + pass + + async def pre_process(self, message, data): + raise NotImplementedError + async def post_process(self, message, data, exception): + raise NotImplementedError + diff --git a/telebot/asyncio_helper.py b/telebot/asyncio_helper.py index 7bb649e..3d1189d 100644 --- a/telebot/asyncio_helper.py +++ b/telebot/asyncio_helper.py @@ -1,28 +1,19 @@ -import asyncio +import asyncio # for future uses from time import time import aiohttp from telebot import types import json -import logging try: import ujson as json except ImportError: import json -import requests -from requests.exceptions import HTTPError, ConnectionError, Timeout - -try: - # noinspection PyUnresolvedReferences - from requests.packages.urllib3 import fields - format_header_param = fields.format_header_param -except ImportError: - format_header_param = None - API_URL = 'https://api.telegram.org/bot{0}/{1}' from datetime import datetime + +import telebot from telebot import util class SessionBase: @@ -44,19 +35,16 @@ READ_TIMEOUT = 30 LONG_POLLING_TIMEOUT = 10 # Should be positive, short polling should be used for testing purposes only (https://core.telegram.org/bots/api#getupdates) +logger = telebot.logger RETRY_ON_ERROR = False RETRY_TIMEOUT = 2 MAX_RETRIES = 15 -CUSTOM_SERIALIZER = None -CUSTOM_REQUEST_SENDER = None - -ENABLE_MIDDLEWARE = False - -async def _process_request(token, url, method='get', params=None, files=None): +async def _process_request(token, url, method='get', params=None, files=None, request_timeout=None): async with await session_manager._get_new_session() as session: - async with session.get(API_URL.format(token, url), params=params, data=files) as response: + async with session.get(API_URL.format(token, url), params=params, data=files, timeout=request_timeout) as response: + logger.debug("Request: method={0} url={1} params={2} files={3} request_timeout={4}".format(method, url, params, files, request_timeout).replace(token, token.split(':')[0] + ":{TOKEN}")) json_result = await _check_result(url, response) if json_result: return json_result['result'] @@ -155,7 +143,7 @@ async def get_webhook_info(token, timeout=None): async def get_updates(token, offset=None, limit=None, - timeout=None, allowed_updates=None, long_polling_timeout=None): + timeout=None, allowed_updates=None, request_timeout=None): method_name = 'getUpdates' params = {} if offset: @@ -166,8 +154,7 @@ async def get_updates(token, offset=None, limit=None, params['timeout'] = timeout elif allowed_updates: params['allowed_updates'] = allowed_updates - params['long_polling_timeout'] = long_polling_timeout if long_polling_timeout else LONG_POLLING_TIMEOUT - return await _process_request(token, method_name, params=params) + return await _process_request(token, method_name, params=params, request_timeout=request_timeout) async def _check_result(method_name, result): """