diff --git a/telebot/__init__.py b/telebot/__init__.py index 1c88c17..2d7b762 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -3008,8 +3008,17 @@ class TeleBot: :return: None """ if not self.use_class_middlewares: - logger.warning('Middleware is not enabled. Pass use_class_middlewares=True to enable it.') + logger.error('Middleware is not enabled. Pass use_class_middlewares=True to enable it.') return + + if not hasattr(middleware, 'update_types'): + logger.error('Middleware has no update_types parameter. Please add list of updates to handle.') + return + + if not hasattr(middleware, 'update_sensitive'): + logger.warning('Middleware has no update_sensitive parameter. Parameter was set to False.') + middleware.update_sensitive = False + self.middlewares.append(middleware) @@ -4065,7 +4074,7 @@ class TeleBot: middlewares = [i for i in self.middlewares if update_type in i.update_types] return middlewares - def _run_middlewares_and_handler(self, message, handlers, middlewares): + def _run_middlewares_and_handler(self, message, handlers, middlewares, update_type): """ This class is made to run handler and middleware in queue. @@ -4079,7 +4088,14 @@ class TeleBot: skip_handler = False if middlewares: for middleware in middlewares: - result = middleware.pre_process(message, data) + if middleware.update_sensitive: + if hasattr(middleware, f'pre_process_{update_type}'): + result = getattr(middleware, f'pre_process_{update_type}')(message, data) + else: + logger.error('Middleware {} does not have pre_process_{} method. pre_process function execution was skipped.'.format(middleware.__class__.__name__, update_type)) + result = None + else: + result = middleware.pre_process(message, data) # We will break this loop if CancelUpdate is returned # Also, we will not run other middlewares if isinstance(result, CancelUpdate): @@ -4134,7 +4150,13 @@ class TeleBot: # remove the bot from data if middlewares: for middleware in middlewares: - middleware.post_process(message, data, handler_error) + if middleware.update_sensitive: + if hasattr(middleware, f'post_process_{update_type}'): + result = getattr(middleware, f'post_process_{update_type}')(message, data, handler_error) + else: + logger.error("Middleware: {} does not have post_process_{} method. Post process function was not executed.".format(middleware.__class__.__name__, update_type)) + else: + result = middleware.post_process(message, data, handler_error) @@ -4153,7 +4175,7 @@ class TeleBot: for message in new_messages: if self.use_class_middlewares: middleware = self._check_middleware(update_type) - self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware) + self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware, update_type=update_type) return else: for message_handler in handlers: diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index e9ee889..769ca6e 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -14,7 +14,7 @@ import telebot.types # storages from telebot.asyncio_storage import StateMemoryStorage, StatePickleStorage -from telebot.asyncio_handler_backends import CancelUpdate, SkipHandler, State +from telebot.asyncio_handler_backends import BaseMiddleware, CancelUpdate, SkipHandler, State from inspect import signature @@ -288,17 +288,24 @@ class AsyncTeleBot: tasks = [] for message in messages: middleware = await self.process_middlewares(update_type) - tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware)) + tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware, update_type)) await asyncio.gather(*tasks) - async def _run_middlewares_and_handlers(self, handlers, message, middlewares): + async def _run_middlewares_and_handlers(self, handlers, message, middlewares, update_type): handler_error = None data = {} process_handler = True params = [] if middlewares: for middleware in middlewares: - middleware_result = await middleware.pre_process(message, data) + if middleware.update_sensitive: + if hasattr(middleware, f'pre_process_{update_type}'): + middleware_result = await getattr(middleware, f'pre_process_{update_type}')(message, data) + else: + logger.error('Middleware {} does not have pre_process_{} method. pre_process function execution was skipped.'.format(middleware.__class__.__name__, update_type)) + middleware_result = None + else: + middleware_result = await middleware.pre_process(message, data) if isinstance(middleware_result, SkipHandler): await middleware.post_process(message, data, handler_error) process_handler = False @@ -354,7 +361,12 @@ class AsyncTeleBot: if middlewares: for middleware in middlewares: - await middleware.post_process(message, data, handler_error) + if middleware.update_sensitive: + if hasattr(middleware, f'post_process_{update_type}'): + await getattr(middleware, f'post_process_{update_type}')(message, data, handler_error) + else: + logger.error('Middleware {} does not have post_process_{} method. post_process function execution was skipped.'.format(middleware.__class__.__name__, update_type)) + else: await middleware.post_process(message, data, handler_error) # update handling async def process_new_updates(self, updates): """ @@ -595,13 +607,21 @@ class AsyncTeleBot: logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.") return False - def setup_middleware(self, middleware): + def setup_middleware(self, middleware: BaseMiddleware): """ Setup middleware. :param middleware: Middleware-class. :return: """ + if not hasattr(middleware, 'update_types'): + logger.error('Middleware has no update_types parameter. Please add list of updates to handle.') + return + + if not hasattr(middleware, 'update_sensitive'): + logger.warning('Middleware has no update_sensitive parameter. Parameter was set to False.') + middleware.update_sensitive = False + self.middlewares.append(middleware) def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs): diff --git a/telebot/asyncio_handler_backends.py b/telebot/asyncio_handler_backends.py index 4ef4555..4f0d174 100644 --- a/telebot/asyncio_handler_backends.py +++ b/telebot/asyncio_handler_backends.py @@ -2,8 +2,35 @@ class BaseMiddleware: """ Base class for middleware. Your middlewares should be inherited from this class. + + Set update_sensitive=True if you want to get different updates on + different functions. For example, if you want to handle pre_process for + message update, then you will have to create pre_process_message function, and + so on. Same applies to post_process. + + .. code-block:: python + class MyMiddleware(BaseMiddleware): + def __init__(self): + self.update_sensitive = True + self.update_types = ['message', 'edited_message'] + + def pre_process_message(self, message, data): + # only message update here + pass + + def post_process_message(self, message, data, exception): + pass # only message update here for post_process + + def pre_process_edited_message(self, message, data): + # only edited_message update here + pass + + def post_process_edited_message(self, message, data, exception): + pass # only edited_message update here for post_process """ + update_sensitive: bool = False + def __init__(self): pass diff --git a/telebot/handler_backends.py b/telebot/handler_backends.py index 4304317..598cf59 100644 --- a/telebot/handler_backends.py +++ b/telebot/handler_backends.py @@ -173,8 +173,35 @@ class BaseMiddleware: """ Base class for middleware. Your middlewares should be inherited from this class. + + Set update_sensitive=True if you want to get different updates on + different functions. For example, if you want to handle pre_process for + message update, then you will have to create pre_process_message function, and + so on. Same applies to post_process. + + .. code-block:: python + class MyMiddleware(BaseMiddleware): + def __init__(self): + self.update_sensitive = True + self.update_types = ['message', 'edited_message'] + + def pre_process_message(self, message, data): + # only message update here + pass + + def post_process_message(self, message, data, exception): + pass # only message update here for post_process + + def pre_process_edited_message(self, message, data): + # only edited_message update here + pass + + def post_process_edited_message(self, message, data, exception): + pass # only edited_message update here for post_process """ + update_sensitive: bool = False + def __init__(self): pass