diff --git a/telebot/__init__.py b/telebot/__init__.py index 567cfcb..800d2f6 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -16,12 +16,16 @@ import telebot.types # storage from telebot.storage import StatePickleStorage, StateMemoryStorage + + logger = logging.getLogger('TeleBot') formatter = logging.Formatter( '%(asctime)s (%(filename)s:%(lineno)d %(threadName)s) %(levelname)s - %(name)s: "%(message)s"' ) +import inspect + console_output_handler = logging.StreamHandler(sys.stderr) console_output_handler.setFormatter(formatter) logger.addHandler(console_output_handler) @@ -29,7 +33,7 @@ logger.addHandler(console_output_handler) logger.setLevel(logging.ERROR) from telebot import apihelper, util, types -from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend +from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend, BaseMiddleware, CancelUpdate, SkipHandler from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter @@ -147,7 +151,7 @@ class TeleBot: def __init__( self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2, next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0, - suppress_middleware_excepions=False, state_storage=StateMemoryStorage() + suppress_middleware_excepions=False, state_storage=StateMemoryStorage(), use_class_middlewares=False ): """ :param token: bot API token @@ -193,7 +197,8 @@ class TeleBot: self.current_states = state_storage - if apihelper.ENABLE_MIDDLEWARE: + self.use_class_middlewares = use_class_middlewares + if apihelper.ENABLE_MIDDLEWARE and not use_class_middlewares: self.typed_middleware_handlers = { 'message': [], 'edited_message': [], @@ -211,6 +216,13 @@ class TeleBot: 'chat_join_request': [] } self.default_middleware_handlers = [] + if apihelper.ENABLE_MIDDLEWARE and use_class_middlewares: + logger.warning( + 'You are using class based middlewares, but you have ' + 'ENABLE_MIDDLEWARE set to True. This is not recommended.' + ) + self.middlewares = [] if use_class_middlewares else None + self.threaded = threaded if self.threaded: @@ -440,6 +452,7 @@ class TeleBot: else: if update.update_id > self.last_update_id: self.last_update_id = update.update_id continue + if update.update_id > self.last_update_id: self.last_update_id = update.update_id @@ -519,46 +532,46 @@ class TeleBot: self._notify_next_handlers(new_messages) self._notify_reply_handlers(new_messages) self.__notify_update(new_messages) - self._notify_command_handlers(self.message_handlers, new_messages) + self._notify_command_handlers(self.message_handlers, new_messages, 'message') def process_new_edited_messages(self, edited_message): - self._notify_command_handlers(self.edited_message_handlers, edited_message) + self._notify_command_handlers(self.edited_message_handlers, edited_message, 'edited_message') def process_new_channel_posts(self, channel_post): - self._notify_command_handlers(self.channel_post_handlers, channel_post) + self._notify_command_handlers(self.channel_post_handlers, channel_post, 'channel_post') def process_new_edited_channel_posts(self, edited_channel_post): - self._notify_command_handlers(self.edited_channel_post_handlers, edited_channel_post) + self._notify_command_handlers(self.edited_channel_post_handlers, edited_channel_post, 'edited_channel_post') def process_new_inline_query(self, new_inline_querys): - self._notify_command_handlers(self.inline_handlers, new_inline_querys) + self._notify_command_handlers(self.inline_handlers, new_inline_querys, 'inline_query') def process_new_chosen_inline_query(self, new_chosen_inline_querys): - self._notify_command_handlers(self.chosen_inline_handlers, new_chosen_inline_querys) + self._notify_command_handlers(self.chosen_inline_handlers, new_chosen_inline_querys, 'chosen_inline_query') def process_new_callback_query(self, new_callback_querys): - self._notify_command_handlers(self.callback_query_handlers, new_callback_querys) + self._notify_command_handlers(self.callback_query_handlers, new_callback_querys, 'callback_query') def process_new_shipping_query(self, new_shipping_querys): - self._notify_command_handlers(self.shipping_query_handlers, new_shipping_querys) + self._notify_command_handlers(self.shipping_query_handlers, new_shipping_querys, 'shipping_query') def process_new_pre_checkout_query(self, pre_checkout_querys): - self._notify_command_handlers(self.pre_checkout_query_handlers, pre_checkout_querys) + self._notify_command_handlers(self.pre_checkout_query_handlers, pre_checkout_querys, 'pre_checkout_query') def process_new_poll(self, polls): - self._notify_command_handlers(self.poll_handlers, polls) + self._notify_command_handlers(self.poll_handlers, polls, 'poll') def process_new_poll_answer(self, poll_answers): - self._notify_command_handlers(self.poll_answer_handlers, poll_answers) + self._notify_command_handlers(self.poll_answer_handlers, poll_answers, 'poll_answer') def process_new_my_chat_member(self, my_chat_members): - self._notify_command_handlers(self.my_chat_member_handlers, my_chat_members) + self._notify_command_handlers(self.my_chat_member_handlers, my_chat_members, 'my_chat_member') def process_new_chat_member(self, chat_members): - self._notify_command_handlers(self.chat_member_handlers, chat_members) + self._notify_command_handlers(self.chat_member_handlers, chat_members, 'chat_member') def process_new_chat_join_request(self, chat_join_request): - self._notify_command_handlers(self.chat_join_request_handlers, chat_join_request) + self._notify_command_handlers(self.chat_join_request_handlers, chat_join_request, 'chat_join_request') def process_middlewares(self, update): for update_type, middlewares in self.typed_middleware_handlers.items(): @@ -2535,6 +2548,20 @@ class TeleBot: chat_id = message.chat.id self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs) + + def setup_middleware(self, middleware: BaseMiddleware): + """ + Register middleware + :param middleware: Subclass of `telebot.handler_backends.BaseMiddleware` + :return: None + """ + if not self.use_class_middlewares: + logger.warning('Middleware is not enabled. Pass use_class_middlewares=True to enable it.') + return + self.middlewares.append(middleware) + + + def set_state(self, user_id: int, state: Union[int, str], chat_id: int=None) -> None: """ Sets a new state of a user. @@ -3500,16 +3527,95 @@ class TeleBot: logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.") return False - def _notify_command_handlers(self, handlers, new_messages): + # middleware check-up method + def _check_middleware(self, update_type): + """ + Check middleware + :param message: + :return: + """ + if self.middlewares: middlewares = [i for i in self.middlewares if update_type in i.update_types] + if not middlewares: return + return middlewares + + def _run_middlewares_and_handler(self, message, handlers, middlewares, *args, **kwargs): + """This class is made to run handler and middleware in queue. + :param handler: handler that should be executed. + :param middleware: middleware that should be executed. + :return: + """ + data = {} + params =[] + handler_error = None + skip_handler = False + if middlewares: + for middleware in middlewares: + result = middleware.pre_process(message, data) + # We will break this loop if CancelUpdate is returned + # Also, we will not run other middlewares + if isinstance(result, CancelUpdate): + return + elif isinstance(result, SkipHandler) and skip_handler is False: + skip_handler = True + + + + + try: + if handlers and not skip_handler: + for handler in handlers: + process_handler = self._test_message_handler(handler, message) + if not process_handler: continue + else: + for i in inspect.signature(handler['function']).parameters: + params.append(i) + if len(params) == 1: + handler['function'](message) + + elif len(params) == 2: + if handler.get('pass_bot') is True: + handler['function'](message, self) + + elif handler.get('pass_bot') is False: + handler['function'](message, data) + + elif len(params) == 3: + if params[2] == 'bot' and handler.get('pass_bot') is True: + handler['function'](message, data, self) + + else: + handler['function'](message, self, data) + + except Exception as e: + handler_error = e + + if not middlewares: + if self.exception_handler: + return self.exception_handler.handle(e) + logging.error(str(e)) + return + if middlewares: + for middleware in middlewares: + middleware.post_process(message, data, handler_error) + + + + + def _notify_command_handlers(self, handlers, new_messages, update_type): """ Notifies command handlers :param handlers: :param new_messages: :return: """ - if len(handlers) == 0: + if len(handlers) == 0 and not self.use_class_middlewares: return + for message in new_messages: + middleware = self._check_middleware(update_type) + if self.use_class_middlewares and middleware: + self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware) + return for message_handler in handlers: if self._test_message_handler(message_handler, message): self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler') diff --git a/telebot/handler_backends.py b/telebot/handler_backends.py index d88457b..f696d6b 100644 --- a/telebot/handler_backends.py +++ b/telebot/handler_backends.py @@ -165,4 +165,43 @@ class StatesGroup: # change value of that variable value.name = ':'.join((cls.__name__, name)) - \ No newline at end of file + +class BaseMiddleware: + """ + Base class for middleware. + Your middlewares should be inherited from this class. + """ + + def __init__(self): + pass + + def pre_process(self, message, data): + raise NotImplementedError + + def post_process(self, message, data, exception): + raise NotImplementedError + + +class SkipHandler: + """ + Class for skipping handlers. + Just return instance of this class + in middleware to skip handler. + Update will go to post_process, + but will skip execution of handler. + """ + + def __init__(self) -> None: + pass + +class CancelUpdate: + """ + Class for canceling updates. + Just return instance of this class + in middleware to skip update. + Update will skip handler and execution + of post_process in middlewares. + """ + + def __init__(self) -> None: + pass \ No newline at end of file