mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
Added middlewares.
Bumped middlewares
This commit is contained in:
parent
ac12d0fc02
commit
388477686b
@ -16,12 +16,16 @@ import telebot.types
|
|||||||
# storage
|
# storage
|
||||||
from telebot.storage import StatePickleStorage, StateMemoryStorage
|
from telebot.storage import StatePickleStorage, StateMemoryStorage
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger('TeleBot')
|
logger = logging.getLogger('TeleBot')
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
'%(asctime)s (%(filename)s:%(lineno)d %(threadName)s) %(levelname)s - %(name)s: "%(message)s"'
|
'%(asctime)s (%(filename)s:%(lineno)d %(threadName)s) %(levelname)s - %(name)s: "%(message)s"'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
console_output_handler = logging.StreamHandler(sys.stderr)
|
console_output_handler = logging.StreamHandler(sys.stderr)
|
||||||
console_output_handler.setFormatter(formatter)
|
console_output_handler.setFormatter(formatter)
|
||||||
logger.addHandler(console_output_handler)
|
logger.addHandler(console_output_handler)
|
||||||
@ -29,7 +33,7 @@ logger.addHandler(console_output_handler)
|
|||||||
logger.setLevel(logging.ERROR)
|
logger.setLevel(logging.ERROR)
|
||||||
|
|
||||||
from telebot import apihelper, util, types
|
from telebot import apihelper, util, types
|
||||||
from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend
|
from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend, BaseMiddleware, CancelUpdate, SkipHandler
|
||||||
from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter
|
from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter
|
||||||
|
|
||||||
|
|
||||||
@ -147,7 +151,7 @@ class TeleBot:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2,
|
self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2,
|
||||||
next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0,
|
next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0,
|
||||||
suppress_middleware_excepions=False, state_storage=StateMemoryStorage()
|
suppress_middleware_excepions=False, state_storage=StateMemoryStorage(), use_class_middlewares=False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param token: bot API token
|
:param token: bot API token
|
||||||
@ -193,7 +197,8 @@ class TeleBot:
|
|||||||
|
|
||||||
self.current_states = state_storage
|
self.current_states = state_storage
|
||||||
|
|
||||||
if apihelper.ENABLE_MIDDLEWARE:
|
self.use_class_middlewares = use_class_middlewares
|
||||||
|
if apihelper.ENABLE_MIDDLEWARE and not use_class_middlewares:
|
||||||
self.typed_middleware_handlers = {
|
self.typed_middleware_handlers = {
|
||||||
'message': [],
|
'message': [],
|
||||||
'edited_message': [],
|
'edited_message': [],
|
||||||
@ -211,6 +216,13 @@ class TeleBot:
|
|||||||
'chat_join_request': []
|
'chat_join_request': []
|
||||||
}
|
}
|
||||||
self.default_middleware_handlers = []
|
self.default_middleware_handlers = []
|
||||||
|
if apihelper.ENABLE_MIDDLEWARE and use_class_middlewares:
|
||||||
|
logger.warning(
|
||||||
|
'You are using class based middlewares, but you have '
|
||||||
|
'ENABLE_MIDDLEWARE set to True. This is not recommended.'
|
||||||
|
)
|
||||||
|
self.middlewares = [] if use_class_middlewares else None
|
||||||
|
|
||||||
|
|
||||||
self.threaded = threaded
|
self.threaded = threaded
|
||||||
if self.threaded:
|
if self.threaded:
|
||||||
@ -441,6 +453,7 @@ class TeleBot:
|
|||||||
if update.update_id > self.last_update_id: self.last_update_id = update.update_id
|
if update.update_id > self.last_update_id: self.last_update_id = update.update_id
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
if update.update_id > self.last_update_id:
|
if update.update_id > self.last_update_id:
|
||||||
self.last_update_id = update.update_id
|
self.last_update_id = update.update_id
|
||||||
if update.message:
|
if update.message:
|
||||||
@ -519,46 +532,46 @@ class TeleBot:
|
|||||||
self._notify_next_handlers(new_messages)
|
self._notify_next_handlers(new_messages)
|
||||||
self._notify_reply_handlers(new_messages)
|
self._notify_reply_handlers(new_messages)
|
||||||
self.__notify_update(new_messages)
|
self.__notify_update(new_messages)
|
||||||
self._notify_command_handlers(self.message_handlers, new_messages)
|
self._notify_command_handlers(self.message_handlers, new_messages, 'message')
|
||||||
|
|
||||||
def process_new_edited_messages(self, edited_message):
|
def process_new_edited_messages(self, edited_message):
|
||||||
self._notify_command_handlers(self.edited_message_handlers, edited_message)
|
self._notify_command_handlers(self.edited_message_handlers, edited_message, 'edited_message')
|
||||||
|
|
||||||
def process_new_channel_posts(self, channel_post):
|
def process_new_channel_posts(self, channel_post):
|
||||||
self._notify_command_handlers(self.channel_post_handlers, channel_post)
|
self._notify_command_handlers(self.channel_post_handlers, channel_post, 'channel_post')
|
||||||
|
|
||||||
def process_new_edited_channel_posts(self, edited_channel_post):
|
def process_new_edited_channel_posts(self, edited_channel_post):
|
||||||
self._notify_command_handlers(self.edited_channel_post_handlers, edited_channel_post)
|
self._notify_command_handlers(self.edited_channel_post_handlers, edited_channel_post, 'edited_channel_post')
|
||||||
|
|
||||||
def process_new_inline_query(self, new_inline_querys):
|
def process_new_inline_query(self, new_inline_querys):
|
||||||
self._notify_command_handlers(self.inline_handlers, new_inline_querys)
|
self._notify_command_handlers(self.inline_handlers, new_inline_querys, 'inline_query')
|
||||||
|
|
||||||
def process_new_chosen_inline_query(self, new_chosen_inline_querys):
|
def process_new_chosen_inline_query(self, new_chosen_inline_querys):
|
||||||
self._notify_command_handlers(self.chosen_inline_handlers, new_chosen_inline_querys)
|
self._notify_command_handlers(self.chosen_inline_handlers, new_chosen_inline_querys, 'chosen_inline_query')
|
||||||
|
|
||||||
def process_new_callback_query(self, new_callback_querys):
|
def process_new_callback_query(self, new_callback_querys):
|
||||||
self._notify_command_handlers(self.callback_query_handlers, new_callback_querys)
|
self._notify_command_handlers(self.callback_query_handlers, new_callback_querys, 'callback_query')
|
||||||
|
|
||||||
def process_new_shipping_query(self, new_shipping_querys):
|
def process_new_shipping_query(self, new_shipping_querys):
|
||||||
self._notify_command_handlers(self.shipping_query_handlers, new_shipping_querys)
|
self._notify_command_handlers(self.shipping_query_handlers, new_shipping_querys, 'shipping_query')
|
||||||
|
|
||||||
def process_new_pre_checkout_query(self, pre_checkout_querys):
|
def process_new_pre_checkout_query(self, pre_checkout_querys):
|
||||||
self._notify_command_handlers(self.pre_checkout_query_handlers, pre_checkout_querys)
|
self._notify_command_handlers(self.pre_checkout_query_handlers, pre_checkout_querys, 'pre_checkout_query')
|
||||||
|
|
||||||
def process_new_poll(self, polls):
|
def process_new_poll(self, polls):
|
||||||
self._notify_command_handlers(self.poll_handlers, polls)
|
self._notify_command_handlers(self.poll_handlers, polls, 'poll')
|
||||||
|
|
||||||
def process_new_poll_answer(self, poll_answers):
|
def process_new_poll_answer(self, poll_answers):
|
||||||
self._notify_command_handlers(self.poll_answer_handlers, poll_answers)
|
self._notify_command_handlers(self.poll_answer_handlers, poll_answers, 'poll_answer')
|
||||||
|
|
||||||
def process_new_my_chat_member(self, my_chat_members):
|
def process_new_my_chat_member(self, my_chat_members):
|
||||||
self._notify_command_handlers(self.my_chat_member_handlers, my_chat_members)
|
self._notify_command_handlers(self.my_chat_member_handlers, my_chat_members, 'my_chat_member')
|
||||||
|
|
||||||
def process_new_chat_member(self, chat_members):
|
def process_new_chat_member(self, chat_members):
|
||||||
self._notify_command_handlers(self.chat_member_handlers, chat_members)
|
self._notify_command_handlers(self.chat_member_handlers, chat_members, 'chat_member')
|
||||||
|
|
||||||
def process_new_chat_join_request(self, chat_join_request):
|
def process_new_chat_join_request(self, chat_join_request):
|
||||||
self._notify_command_handlers(self.chat_join_request_handlers, chat_join_request)
|
self._notify_command_handlers(self.chat_join_request_handlers, chat_join_request, 'chat_join_request')
|
||||||
|
|
||||||
def process_middlewares(self, update):
|
def process_middlewares(self, update):
|
||||||
for update_type, middlewares in self.typed_middleware_handlers.items():
|
for update_type, middlewares in self.typed_middleware_handlers.items():
|
||||||
@ -2535,6 +2548,20 @@ class TeleBot:
|
|||||||
chat_id = message.chat.id
|
chat_id = message.chat.id
|
||||||
self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs)
|
self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_middleware(self, middleware: BaseMiddleware):
|
||||||
|
"""
|
||||||
|
Register middleware
|
||||||
|
:param middleware: Subclass of `telebot.handler_backends.BaseMiddleware`
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
if not self.use_class_middlewares:
|
||||||
|
logger.warning('Middleware is not enabled. Pass use_class_middlewares=True to enable it.')
|
||||||
|
return
|
||||||
|
self.middlewares.append(middleware)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_state(self, user_id: int, state: Union[int, str], chat_id: int=None) -> None:
|
def set_state(self, user_id: int, state: Union[int, str], chat_id: int=None) -> None:
|
||||||
"""
|
"""
|
||||||
Sets a new state of a user.
|
Sets a new state of a user.
|
||||||
@ -3500,16 +3527,95 @@ class TeleBot:
|
|||||||
logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.")
|
logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _notify_command_handlers(self, handlers, new_messages):
|
# middleware check-up method
|
||||||
|
def _check_middleware(self, update_type):
|
||||||
|
"""
|
||||||
|
Check middleware
|
||||||
|
:param message:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.middlewares: middlewares = [i for i in self.middlewares if update_type in i.update_types]
|
||||||
|
if not middlewares: return
|
||||||
|
return middlewares
|
||||||
|
|
||||||
|
def _run_middlewares_and_handler(self, message, handlers, middlewares, *args, **kwargs):
|
||||||
|
"""This class is made to run handler and middleware in queue.
|
||||||
|
:param handler: handler that should be executed.
|
||||||
|
:param middleware: middleware that should be executed.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
data = {}
|
||||||
|
params =[]
|
||||||
|
handler_error = None
|
||||||
|
skip_handler = False
|
||||||
|
if middlewares:
|
||||||
|
for middleware in middlewares:
|
||||||
|
result = middleware.pre_process(message, data)
|
||||||
|
# We will break this loop if CancelUpdate is returned
|
||||||
|
# Also, we will not run other middlewares
|
||||||
|
if isinstance(result, CancelUpdate):
|
||||||
|
return
|
||||||
|
elif isinstance(result, SkipHandler) and skip_handler is False:
|
||||||
|
skip_handler = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
if handlers and not skip_handler:
|
||||||
|
for handler in handlers:
|
||||||
|
process_handler = self._test_message_handler(handler, message)
|
||||||
|
if not process_handler: continue
|
||||||
|
else:
|
||||||
|
for i in inspect.signature(handler['function']).parameters:
|
||||||
|
params.append(i)
|
||||||
|
if len(params) == 1:
|
||||||
|
handler['function'](message)
|
||||||
|
|
||||||
|
elif len(params) == 2:
|
||||||
|
if handler.get('pass_bot') is True:
|
||||||
|
handler['function'](message, self)
|
||||||
|
|
||||||
|
elif handler.get('pass_bot') is False:
|
||||||
|
handler['function'](message, data)
|
||||||
|
|
||||||
|
elif len(params) == 3:
|
||||||
|
if params[2] == 'bot' and handler.get('pass_bot') is True:
|
||||||
|
handler['function'](message, data, self)
|
||||||
|
|
||||||
|
else:
|
||||||
|
handler['function'](message, self, data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
handler_error = e
|
||||||
|
|
||||||
|
if not middlewares:
|
||||||
|
if self.exception_handler:
|
||||||
|
return self.exception_handler.handle(e)
|
||||||
|
logging.error(str(e))
|
||||||
|
return
|
||||||
|
if middlewares:
|
||||||
|
for middleware in middlewares:
|
||||||
|
middleware.post_process(message, data, handler_error)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _notify_command_handlers(self, handlers, new_messages, update_type):
|
||||||
"""
|
"""
|
||||||
Notifies command handlers
|
Notifies command handlers
|
||||||
:param handlers:
|
:param handlers:
|
||||||
:param new_messages:
|
:param new_messages:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if len(handlers) == 0:
|
if len(handlers) == 0 and not self.use_class_middlewares:
|
||||||
return
|
return
|
||||||
|
|
||||||
for message in new_messages:
|
for message in new_messages:
|
||||||
|
middleware = self._check_middleware(update_type)
|
||||||
|
if self.use_class_middlewares and middleware:
|
||||||
|
self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware)
|
||||||
|
return
|
||||||
for message_handler in handlers:
|
for message_handler in handlers:
|
||||||
if self._test_message_handler(message_handler, message):
|
if self._test_message_handler(message_handler, message):
|
||||||
self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler')
|
self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler')
|
||||||
|
@ -166,3 +166,42 @@ class StatesGroup:
|
|||||||
value.name = ':'.join((cls.__name__, name))
|
value.name = ':'.join((cls.__name__, name))
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMiddleware:
|
||||||
|
"""
|
||||||
|
Base class for middleware.
|
||||||
|
Your middlewares should be inherited from this class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def pre_process(self, message, data):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def post_process(self, message, data, exception):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class SkipHandler:
|
||||||
|
"""
|
||||||
|
Class for skipping handlers.
|
||||||
|
Just return instance of this class
|
||||||
|
in middleware to skip handler.
|
||||||
|
Update will go to post_process,
|
||||||
|
but will skip execution of handler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class CancelUpdate:
|
||||||
|
"""
|
||||||
|
Class for canceling updates.
|
||||||
|
Just return instance of this class
|
||||||
|
in middleware to skip update.
|
||||||
|
Update will skip handler and execution
|
||||||
|
of post_process in middlewares.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user