1
0
mirror of https://github.com/eternnoir/pyTelegramBotAPI.git synced 2023-08-10 21:12:57 +03:00

Multiple middlewares allowed for async

This commit is contained in:
coder2020official 2022-03-07 13:30:39 +05:00
parent f69a2ba044
commit be0557c2b5
2 changed files with 34 additions and 24 deletions

View File

@ -3088,6 +3088,8 @@ class TeleBot:
logger.warning("register_message_handler: 'content_types' filter should be List of strings (content types), not string.") logger.warning("register_message_handler: 'content_types' filter should be List of strings (content types), not string.")
content_types = [content_types] content_types = [content_types]
handler_dict = self._build_handler_dict(callback, handler_dict = self._build_handler_dict(callback,
chat_types=chat_types, chat_types=chat_types,
content_types=content_types, content_types=content_types,

View File

@ -268,21 +268,23 @@ class AsyncTeleBot:
""" """
tasks = [] tasks = []
for message in messages: for message in messages:
middleware = await self.process_middlewares(message, update_type) middleware = await self.process_middlewares(update_type)
tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware)) tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware))
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
async def _run_middlewares_and_handlers(self, handlers, message, middleware): async def _run_middlewares_and_handlers(self, handlers, message, middlewares):
handler_error = None handler_error = None
data = {} data = {}
process_handler = True process_handler = True
if middleware:
middleware_result = await middleware.pre_process(message, data) if middlewares:
if isinstance(middleware_result, SkipHandler): for middleware in middlewares:
await middleware.post_process(message, data, handler_error) middleware_result = await middleware.pre_process(message, data)
process_handler = False if isinstance(middleware_result, SkipHandler):
if isinstance(middleware_result, CancelUpdate): await middleware.post_process(message, data, handler_error)
return process_handler = False
if isinstance(middleware_result, CancelUpdate):
return
for handler in handlers: for handler in handlers:
if not process_handler: if not process_handler:
break break
@ -299,15 +301,20 @@ class AsyncTeleBot:
if len(params) == 1: if len(params) == 1:
await handler['function'](message) await handler['function'](message)
break break
if params[1] == 'data' and handler.get('pass_bot') is True: elif len(params) == 2:
await handler['function'](message, data, self) if handler['pass_bot']:
break await handler['function'](message, self)
elif params[1] == 'data' and handler.get('pass_bot') is False: break
await handler['function'](message, data) else:
break await handler['function'](message, data)
elif params[1] != 'data' and handler.get('pass_bot') is True: break
await handler['function'](message, self) elif len(params) == 3:
break if handler['pass_bot'] and params[1] == 'bot':
await handler['function'](message, self, data)
break
else:
await handler['function'](message, data)
break
except Exception as e: except Exception as e:
handler_error = e handler_error = e
@ -317,8 +324,9 @@ class AsyncTeleBot:
logging.error(str(e)) logging.error(str(e))
return return
if middleware: if middlewares:
await middleware.post_process(message, data, handler_error) for middleware in middlewares:
await middleware.post_process(message, data, handler_error)
# update handling # update handling
async def process_new_updates(self, updates): async def process_new_updates(self, updates):
""" """
@ -463,10 +471,10 @@ class AsyncTeleBot:
async def process_chat_join_request(self, chat_join_request): async def process_chat_join_request(self, chat_join_request):
await self._process_updates(self.chat_join_request_handlers, chat_join_request, 'chat_join_request') await self._process_updates(self.chat_join_request_handlers, chat_join_request, 'chat_join_request')
async def process_middlewares(self, update, update_type): async def process_middlewares(self, update_type):
for middleware in self.middlewares: if self.middlewares:
if update_type in middleware.update_types: middlewares = [middleware for middleware in self.middlewares if update_type in middleware.update_types]
return middleware return middlewares
return None return None
async def __notify_update(self, new_messages): async def __notify_update(self, new_messages):