diff --git a/telebot/__init__.py b/telebot/__init__.py index fa4ee13..c39ce2a 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -2979,6 +2979,15 @@ class TeleBot: if not self.use_class_middlewares: logger.warning('Class-based middlewares are not enabled. Pass use_class_middlewares=True to enable it.') return + + if not hasattr(middleware, 'update_types'): + logger.error('Middleware has no update_types parameter. Please add list of updates to handle.') + return + + if not hasattr(middleware, 'update_sensitive'): + logger.warning('Middleware has no update_sensitive parameter. Parameter was set to False.') + middleware.update_sensitive = False + self.middlewares.append(middleware) def set_state(self, user_id: int, state: Union[int, str, State], chat_id: int=None) -> None: @@ -4037,7 +4046,7 @@ class TeleBot: middlewares = [i for i in self.middlewares if update_type in i.update_types] return middlewares - def _run_middlewares_and_handler(self, message, handlers, middlewares): + def _run_middlewares_and_handler(self, message, handlers, middlewares, update_type): """ This class is made to run handler and middleware in queue. @@ -4051,7 +4060,14 @@ class TeleBot: skip_handler = False if middlewares: for middleware in middlewares: - result = middleware.pre_process(message, data) + if middleware.update_sensitive: + if hasattr(middleware, f'pre_process_{update_type}'): + result = getattr(middleware, f'pre_process_{update_type}')(message, data) + else: + logger.error('Middleware {} does not have pre_process_{} method. pre_process function execution was skipped.'.format(middleware.__class__.__name__, update_type)) + result = None + else: + result = middleware.pre_process(message, data) # We will break this loop if CancelUpdate is returned # Also, we will not run other middlewares if isinstance(result, CancelUpdate): @@ -4102,7 +4118,13 @@ class TeleBot: if middlewares: for middleware in middlewares: - middleware.post_process(message, data, handler_error) + if middleware.update_sensitive: + if hasattr(middleware, f'post_process_{update_type}'): + result = getattr(middleware, f'post_process_{update_type}')(message, data, handler_error) + else: + logger.error("Middleware: {} does not have post_process_{} method. Post process function was not executed.".format(middleware.__class__.__name__, update_type)) + else: + result = middleware.post_process(message, data, handler_error) def _notify_command_handlers(self, handlers, new_messages, update_type): """ @@ -4118,7 +4140,7 @@ class TeleBot: for message in new_messages: if self.use_class_middlewares: middleware = self._check_middleware(update_type) - self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware) + self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware, update_type=update_type) return else: for message_handler in handlers: diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index f1f1c83..2241316 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -14,7 +14,7 @@ import telebot.types # storages from telebot.asyncio_storage import StateMemoryStorage, StatePickleStorage -from telebot.asyncio_handler_backends import CancelUpdate, SkipHandler, State +from telebot.asyncio_handler_backends import BaseMiddleware, CancelUpdate, SkipHandler, State from inspect import signature @@ -288,17 +288,24 @@ class AsyncTeleBot: tasks = [] for message in messages: middleware = await self.process_middlewares(update_type) - tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware)) + tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware, update_type)) await asyncio.gather(*tasks) - async def _run_middlewares_and_handlers(self, handlers, message, middlewares): + async def _run_middlewares_and_handlers(self, handlers, message, middlewares, update_type): handler_error = None data = {} process_handler = True params = [] if middlewares: for middleware in middlewares: - middleware_result = await middleware.pre_process(message, data) + if middleware.update_sensitive: + if hasattr(middleware, f'pre_process_{update_type}'): + middleware_result = await getattr(middleware, f'pre_process_{update_type}')(message, data) + else: + logger.error('Middleware {} does not have pre_process_{} method. pre_process function execution was skipped.'.format(middleware.__class__.__name__, update_type)) + middleware_result = None + else: + middleware_result = await middleware.pre_process(message, data) if isinstance(middleware_result, SkipHandler): await middleware.post_process(message, data, handler_error) process_handler = False @@ -356,7 +363,12 @@ class AsyncTeleBot: if middlewares: for middleware in middlewares: - await middleware.post_process(message, data, handler_error) + if middleware.update_sensitive: + if hasattr(middleware, f'post_process_{update_type}'): + await getattr(middleware, f'post_process_{update_type}')(message, data, handler_error) + else: + logger.error('Middleware {} does not have post_process_{} method. post_process function execution was skipped.'.format(middleware.__class__.__name__, update_type)) + else: await middleware.post_process(message, data, handler_error) # update handling async def process_new_updates(self, updates): """ @@ -597,13 +609,21 @@ class AsyncTeleBot: logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.") return False - def setup_middleware(self, middleware): + def setup_middleware(self, middleware: BaseMiddleware): """ Setup middleware. :param middleware: Middleware-class. :return: """ + if not hasattr(middleware, 'update_types'): + logger.error('Middleware has no update_types parameter. Please add list of updates to handle.') + return + + if not hasattr(middleware, 'update_sensitive'): + logger.warning('Middleware has no update_sensitive parameter. Parameter was set to False.') + middleware.update_sensitive = False + self.middlewares.append(middleware) def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs): diff --git a/telebot/asyncio_handler_backends.py b/telebot/asyncio_handler_backends.py index 4ef4555..4f0d174 100644 --- a/telebot/asyncio_handler_backends.py +++ b/telebot/asyncio_handler_backends.py @@ -2,8 +2,35 @@ class BaseMiddleware: """ Base class for middleware. Your middlewares should be inherited from this class. + + Set update_sensitive=True if you want to get different updates on + different functions. For example, if you want to handle pre_process for + message update, then you will have to create pre_process_message function, and + so on. Same applies to post_process. + + .. code-block:: python + class MyMiddleware(BaseMiddleware): + def __init__(self): + self.update_sensitive = True + self.update_types = ['message', 'edited_message'] + + def pre_process_message(self, message, data): + # only message update here + pass + + def post_process_message(self, message, data, exception): + pass # only message update here for post_process + + def pre_process_edited_message(self, message, data): + # only edited_message update here + pass + + def post_process_edited_message(self, message, data, exception): + pass # only edited_message update here for post_process """ + update_sensitive: bool = False + def __init__(self): pass diff --git a/telebot/handler_backends.py b/telebot/handler_backends.py index 4304317..598cf59 100644 --- a/telebot/handler_backends.py +++ b/telebot/handler_backends.py @@ -173,8 +173,35 @@ class BaseMiddleware: """ Base class for middleware. Your middlewares should be inherited from this class. + + Set update_sensitive=True if you want to get different updates on + different functions. For example, if you want to handle pre_process for + message update, then you will have to create pre_process_message function, and + so on. Same applies to post_process. + + .. code-block:: python + class MyMiddleware(BaseMiddleware): + def __init__(self): + self.update_sensitive = True + self.update_types = ['message', 'edited_message'] + + def pre_process_message(self, message, data): + # only message update here + pass + + def post_process_message(self, message, data, exception): + pass # only message update here for post_process + + def pre_process_edited_message(self, message, data): + # only edited_message update here + pass + + def post_process_edited_message(self, message, data, exception): + pass # only edited_message update here for post_process """ + update_sensitive: bool = False + def __init__(self): pass