Merge pull request #1722 from Badiboy/master

Handlers and Middlewares processing union
This commit is contained in:
Badiboy 2022-09-24 22:16:33 +03:00 committed by GitHub
commit 7c9b01b10a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 134 additions and 105 deletions

View File

@ -497,6 +497,7 @@ class TeleBot:
webhook_url = "{}://{}:{}/{}".format(protocol, listen, port, url_path) webhook_url = "{}://{}:{}/{}".format(protocol, listen, port, url_path)
if certificate and certificate_key: if certificate and certificate_key:
# noinspection PyTypeChecker
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(certificate, certificate_key) ssl_ctx.load_cert_chain(certificate, certificate_key)
@ -1115,11 +1116,6 @@ class TeleBot:
def _exec_task(self, task, *args, **kwargs): def _exec_task(self, task, *args, **kwargs):
if kwargs:
if kwargs.pop('task_type', "") == 'handler':
if kwargs.pop('pass_bot', False):
kwargs['bot'] = self
if self.threaded: if self.threaded:
self.worker_pool.put(task, *args, **kwargs) self.worker_pool.put(task, *args, **kwargs)
else: else:
@ -4894,8 +4890,14 @@ class TeleBot:
if not isinstance(regexp, str): if not isinstance(regexp, str):
logger.error(f"{method_name}: Regexp filter should be string. Not able to use the supplied type.") logger.error(f"{method_name}: Regexp filter should be string. Not able to use the supplied type.")
def message_handler(self, commands: Optional[List[str]]=None, regexp: Optional[str]=None, func: Optional[Callable]=None, def message_handler(
content_types: Optional[List[str]]=None, chat_types: Optional[List[str]]=None, **kwargs): self,
commands: Optional[List[str]]=None,
regexp: Optional[str]=None,
func: Optional[Callable]=None,
content_types: Optional[List[str]]=None,
chat_types: Optional[List[str]]=None,
**kwargs):
""" """
Handles New incoming message of any kind - text, photo, sticker, etc. Handles New incoming message of any kind - text, photo, sticker, etc.
As a parameter to the decorator function, it passes :class:`telebot.types.Message` object. As a parameter to the decorator function, it passes :class:`telebot.types.Message` object.
@ -5979,7 +5981,7 @@ class TeleBot:
return False return False
# middleware check-up method # middleware check-up method
def _check_middleware(self, update_type): def _get_middlewares(self, update_type):
""" """
Check middleware Check middleware
@ -5993,100 +5995,115 @@ class TeleBot:
def _run_middlewares_and_handler(self, message, handlers, middlewares, update_type): def _run_middlewares_and_handler(self, message, handlers, middlewares, update_type):
""" """
This class is made to run handler and middleware in queue. This method is made to run handlers and middlewares in queue.
:param handler: handler that should be executed. :param message: received message (update part) to process with handlers and/or middlewares
:param middleware: middleware that should be executed. :param handlers: all created handlers (not filtered)
:param middlewares: middlewares that should be executed (already filtered)
:param update_type: handler/update type (Update field name)
:return: :return:
""" """
data = {}
params =[]
handler_error = None
skip_handlers = False
if middlewares: if not self.use_class_middlewares:
for middleware in middlewares: if handlers:
if middleware.update_sensitive:
if hasattr(middleware, f'pre_process_{update_type}'):
result = getattr(middleware, f'pre_process_{update_type}')(message, data)
else:
logger.error('Middleware {} does not have pre_process_{} method. pre_process function execution was skipped.'.format(middleware.__class__.__name__, update_type))
result = None
else:
result = middleware.pre_process(message, data)
# We will break this loop if CancelUpdate is returned
# Also, we will not run other middlewares
if isinstance(result, CancelUpdate):
return
elif isinstance(result, SkipHandler):
skip_handlers = True
if handlers and not(skip_handlers):
try:
for handler in handlers: for handler in handlers:
process_handler = self._test_message_handler(handler, message) if self._test_message_handler(handler, message):
if not process_handler: continue if handler.get('pass_bot', False):
for i in inspect.signature(handler['function']).parameters: handler['function'](message, bot = self)
params.append(i)
if len(params) == 1:
handler['function'](message)
elif "data" in params:
if len(params) == 2:
handler['function'](message, data)
elif len(params) == 3:
handler['function'](message, data=data, bot=self)
else: else:
logger.error("It is not allowed to pass data and values inside data to the handler. Check your handler: {}".format(handler['function'])) handler['function'](message)
return break
else: else:
data_copy = data.copy() data = {}
for key in list(data_copy): params =[]
# remove data from data_copy if handler does not accept it handler_error = None
if key not in params: skip_handlers = False
del data_copy[key]
if handler.get('pass_bot'):
data_copy["bot"] = self
if len(data_copy) > len(params) - 1: # remove the message parameter
logger.error("You are passing more parameters than the handler needs. Check your handler: {}".format(handler['function']))
return
handler["function"](message, **data_copy)
break
except Exception as e:
handler_error = e
if self.exception_handler:
self.exception_handler.handle(e)
else:
logging.error(str(e))
logger.debug("Exception traceback:\n%s", traceback.format_exc())
if middlewares: if middlewares:
for middleware in middlewares: for middleware in middlewares:
if middleware.update_sensitive: if middleware.update_sensitive:
if hasattr(middleware, f'post_process_{update_type}'): if hasattr(middleware, f'pre_process_{update_type}'):
getattr(middleware, f'post_process_{update_type}')(message, data, handler_error) result = getattr(middleware, f'pre_process_{update_type}')(message, data)
else:
logger.error('Middleware {} does not have pre_process_{} method. pre_process function execution was skipped.'.format(middleware.__class__.__name__, update_type))
result = None
else: else:
logger.error("Middleware: {} does not have post_process_{} method. Post process function was not executed.".format(middleware.__class__.__name__, update_type)) result = middleware.pre_process(message, data)
else: # We will break this loop if CancelUpdate is returned
middleware.post_process(message, data, handler_error) # Also, we will not run other middlewares
if isinstance(result, CancelUpdate):
return
elif isinstance(result, SkipHandler):
skip_handlers = True
if handlers and not(skip_handlers):
try:
for handler in handlers:
process_handler = self._test_message_handler(handler, message)
if not process_handler: continue
for i in inspect.signature(handler['function']).parameters:
params.append(i)
if len(params) == 1:
handler['function'](message)
elif "data" in params:
if len(params) == 2:
handler['function'](message, data)
elif len(params) == 3:
handler['function'](message, data=data, bot=self)
else:
logger.error("It is not allowed to pass data and values inside data to the handler. Check your handler: {}".format(handler['function']))
return
else:
data_copy = data.copy()
for key in list(data_copy):
# remove data from data_copy if handler does not accept it
if key not in params:
del data_copy[key]
if handler.get('pass_bot'):
data_copy["bot"] = self
if len(data_copy) > len(params) - 1: # remove the message parameter
logger.error("You are passing more parameters than the handler needs. Check your handler: {}".format(handler['function']))
return
handler["function"](message, **data_copy)
break
except Exception as e:
handler_error = e
if self.exception_handler:
self.exception_handler.handle(e)
else:
logger.error(str(e))
logger.debug("Exception traceback:\n%s", traceback.format_exc())
if middlewares:
for middleware in middlewares:
if middleware.update_sensitive:
if hasattr(middleware, f'post_process_{update_type}'):
getattr(middleware, f'post_process_{update_type}')(message, data, handler_error)
else:
logger.error("Middleware: {} does not have post_process_{} method. Post process function was not executed.".format(middleware.__class__.__name__, update_type))
else:
middleware.post_process(message, data, handler_error)
def _notify_command_handlers(self, handlers, new_messages, update_type): def _notify_command_handlers(self, handlers, new_messages, update_type):
""" """
Notifies command handlers. Notifies command handlers.
:param handlers: :param handlers: all created handlers
:param new_messages: :param new_messages: received messages to proceed
:param update_type: handler/update type (Update fields)
:return: :return:
""" """
if not(handlers) and not(self.use_class_middlewares): if not(handlers) and not(self.use_class_middlewares):
return return
if self.use_class_middlewares:
middlewares = self._get_middlewares(update_type)
else:
middlewares = None
for message in new_messages: for message in new_messages:
if not self.use_class_middlewares: self._exec_task(
for message_handler in handlers: self._run_middlewares_and_handler,
if self._test_message_handler(message_handler, message): message,
self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler') handlers=handlers,
break middlewares=middlewares,
else: update_type=update_type)
middleware = self._check_middleware(update_type)
self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware, update_type=update_type)
return

View File

@ -359,7 +359,8 @@ class AsyncTeleBot:
await self.close_session() await self.close_session()
logger.warning('Polling is stopped.') logger.warning('Polling is stopped.')
def _loop_create_task(self, coro): @staticmethod
def _loop_create_task(coro):
return asyncio.create_task(coro) return asyncio.create_task(coro)
async def _process_updates(self, handlers, messages, update_type): async def _process_updates(self, handlers, messages, update_type):
@ -371,12 +372,22 @@ class AsyncTeleBot:
:return: :return:
""" """
tasks = [] tasks = []
middlewares = await self._get_middlewares(update_type)
for message in messages: for message in messages:
middleware = await self.process_middlewares(update_type) tasks.append(self._run_middlewares_and_handlers(message, handlers, middlewares, update_type))
tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware, update_type))
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
async def _run_middlewares_and_handlers(self, handlers, message, middlewares, update_type): async def _run_middlewares_and_handlers(self, message, handlers, middlewares, update_type):
"""
This method is made to run handlers and middlewares in queue.
:param message: received message (update part) to process with handlers and/or middlewares
:param handlers: all created handlers (not filtered)
:param middlewares: middlewares that should be executed (already filtered)
:param update_type: handler/update type (Update field name)
:return:
"""
handler_error = None handler_error = None
data = {} data = {}
skip_handlers = False skip_handlers = False
@ -446,7 +457,7 @@ class AsyncTeleBot:
else: else:
logger.error('Middleware {} does not have post_process_{} method. post_process function execution was skipped.'.format(middleware.__class__.__name__, update_type)) logger.error('Middleware {} does not have post_process_{} method. post_process function execution was skipped.'.format(middleware.__class__.__name__, update_type))
else: await middleware.post_process(message, data, handler_error) else: await middleware.post_process(message, data, handler_error)
# update handling
async def process_new_updates(self, updates: List[types.Update]): async def process_new_updates(self, updates: List[types.Update]):
""" """
Process new updates. Process new updates.
@ -635,7 +646,7 @@ class AsyncTeleBot:
""" """
await self._process_updates(self.chat_join_request_handlers, chat_join_request, 'chat_join_request') await self._process_updates(self.chat_join_request_handlers, chat_join_request, 'chat_join_request')
async def process_middlewares(self, update_type): async def _get_middlewares(self, update_type):
""" """
:meta private: :meta private:
""" """

View File

@ -1,6 +1,5 @@
""" """
This file is used by TeleBot.run_webhooks() function. This file is used by TeleBot.run_webhooks() function.
Fastapi is required to run this script. Fastapi is required to run this script.
""" """
@ -15,15 +14,11 @@ try:
except ImportError: except ImportError:
fastapi_installed = False fastapi_installed = False
from telebot.types import Update from telebot.types import Update
from typing import Optional from typing import Optional
class SyncWebhookListener: class SyncWebhookListener:
def __init__(self, bot, def __init__(self, bot,
secret_token: str, host: Optional[str]="127.0.0.1", secret_token: str, host: Optional[str]="127.0.0.1",
@ -33,13 +28,13 @@ class SyncWebhookListener:
debug: Optional[bool]=False debug: Optional[bool]=False
) -> None: ) -> None:
""" """
Aynchronous implementation of webhook listener Synchronous implementation of webhook listener
for asynchronous version of telebot. for synchronous version of telebot.
Not supposed to be used manually by user. Not supposed to be used manually by user.
Use AsyncTeleBot.run_webhooks() instead. Use TeleBot.run_webhooks() instead.
:param bot: AsyncTeleBot instance. :param bot: TeleBot instance.
:type bot: telebot.async_telebot.AsyncTeleBot :type bot: telebot.TeleBot
:param secret_token: Telegram secret token :param secret_token: Telegram secret token
:type secret_token: str :type secret_token: str
@ -77,7 +72,8 @@ class SyncWebhookListener:
self._prepare_endpoint_urls() self._prepare_endpoint_urls()
def _check_dependencies(self): @staticmethod
def _check_dependencies():
if not fastapi_installed: if not fastapi_installed:
raise ImportError('Fastapi or uvicorn is not installed. Please install it via pip.') raise ImportError('Fastapi or uvicorn is not installed. Please install it via pip.')

View File

@ -41,7 +41,10 @@ class StateStorageBase:
def get_state(self, chat_id, user_id): def get_state(self, chat_id, user_id):
raise NotImplementedError raise NotImplementedError
def get_interactive_data(self, chat_id, user_id):
raise NotImplementedError
def save(self, chat_id, user_id, data): def save(self, chat_id, user_id, data):
raise NotImplementedError raise NotImplementedError

View File

@ -3,6 +3,7 @@ from telebot.storage.base_storage import StateStorageBase, StateContext
class StateMemoryStorage(StateStorageBase): class StateMemoryStorage(StateStorageBase):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__()
self.data = {} self.data = {}
# #
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}

View File

@ -5,8 +5,8 @@ import pickle
class StatePickleStorage(StateStorageBase): class StatePickleStorage(StateStorageBase):
# noinspection PyMissingConstructor
def __init__(self, file_path="./.state-save/states.pkl") -> None: def __init__(self, file_path="./.state-save/states.pkl") -> None:
super().__init__()
self.file_path = file_path self.file_path = file_path
self.create_dir() self.create_dir()
self.data = self.read() self.data = self.read()

View File

@ -16,6 +16,7 @@ class StateRedisStorage(StateStorageBase):
TeleBot(storage=StateRedisStorage()) TeleBot(storage=StateRedisStorage())
""" """
def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'): def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'):
super().__init__()
self.redis = ConnectionPool(host=host, port=port, db=db, password=password) self.redis = ConnectionPool(host=host, port=port, db=db, password=password)
#self.con = Redis(connection_pool=self.redis) -> use this when necessary #self.con = Redis(connection_pool=self.redis) -> use this when necessary
# #