mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
Added middlewares.
Bumped middlewares
This commit is contained in:
parent
ac12d0fc02
commit
388477686b
@ -16,12 +16,16 @@ import telebot.types
|
||||
# storage
|
||||
from telebot.storage import StatePickleStorage, StateMemoryStorage
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger('TeleBot')
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s (%(filename)s:%(lineno)d %(threadName)s) %(levelname)s - %(name)s: "%(message)s"'
|
||||
)
|
||||
|
||||
import inspect
|
||||
|
||||
console_output_handler = logging.StreamHandler(sys.stderr)
|
||||
console_output_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_output_handler)
|
||||
@ -29,7 +33,7 @@ logger.addHandler(console_output_handler)
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
from telebot import apihelper, util, types
|
||||
from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend
|
||||
from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend, BaseMiddleware, CancelUpdate, SkipHandler
|
||||
from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter
|
||||
|
||||
|
||||
@ -147,7 +151,7 @@ class TeleBot:
|
||||
def __init__(
|
||||
self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2,
|
||||
next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0,
|
||||
suppress_middleware_excepions=False, state_storage=StateMemoryStorage()
|
||||
suppress_middleware_excepions=False, state_storage=StateMemoryStorage(), use_class_middlewares=False
|
||||
):
|
||||
"""
|
||||
:param token: bot API token
|
||||
@ -193,7 +197,8 @@ class TeleBot:
|
||||
|
||||
self.current_states = state_storage
|
||||
|
||||
if apihelper.ENABLE_MIDDLEWARE:
|
||||
self.use_class_middlewares = use_class_middlewares
|
||||
if apihelper.ENABLE_MIDDLEWARE and not use_class_middlewares:
|
||||
self.typed_middleware_handlers = {
|
||||
'message': [],
|
||||
'edited_message': [],
|
||||
@ -211,6 +216,13 @@ class TeleBot:
|
||||
'chat_join_request': []
|
||||
}
|
||||
self.default_middleware_handlers = []
|
||||
if apihelper.ENABLE_MIDDLEWARE and use_class_middlewares:
|
||||
logger.warning(
|
||||
'You are using class based middlewares, but you have '
|
||||
'ENABLE_MIDDLEWARE set to True. This is not recommended.'
|
||||
)
|
||||
self.middlewares = [] if use_class_middlewares else None
|
||||
|
||||
|
||||
self.threaded = threaded
|
||||
if self.threaded:
|
||||
@ -440,6 +452,7 @@ class TeleBot:
|
||||
else:
|
||||
if update.update_id > self.last_update_id: self.last_update_id = update.update_id
|
||||
continue
|
||||
|
||||
|
||||
if update.update_id > self.last_update_id:
|
||||
self.last_update_id = update.update_id
|
||||
@ -519,46 +532,46 @@ class TeleBot:
|
||||
self._notify_next_handlers(new_messages)
|
||||
self._notify_reply_handlers(new_messages)
|
||||
self.__notify_update(new_messages)
|
||||
self._notify_command_handlers(self.message_handlers, new_messages)
|
||||
self._notify_command_handlers(self.message_handlers, new_messages, 'message')
|
||||
|
||||
def process_new_edited_messages(self, edited_message):
|
||||
self._notify_command_handlers(self.edited_message_handlers, edited_message)
|
||||
self._notify_command_handlers(self.edited_message_handlers, edited_message, 'edited_message')
|
||||
|
||||
def process_new_channel_posts(self, channel_post):
|
||||
self._notify_command_handlers(self.channel_post_handlers, channel_post)
|
||||
self._notify_command_handlers(self.channel_post_handlers, channel_post, 'channel_post')
|
||||
|
||||
def process_new_edited_channel_posts(self, edited_channel_post):
|
||||
self._notify_command_handlers(self.edited_channel_post_handlers, edited_channel_post)
|
||||
self._notify_command_handlers(self.edited_channel_post_handlers, edited_channel_post, 'edited_channel_post')
|
||||
|
||||
def process_new_inline_query(self, new_inline_querys):
|
||||
self._notify_command_handlers(self.inline_handlers, new_inline_querys)
|
||||
self._notify_command_handlers(self.inline_handlers, new_inline_querys, 'inline_query')
|
||||
|
||||
def process_new_chosen_inline_query(self, new_chosen_inline_querys):
|
||||
self._notify_command_handlers(self.chosen_inline_handlers, new_chosen_inline_querys)
|
||||
self._notify_command_handlers(self.chosen_inline_handlers, new_chosen_inline_querys, 'chosen_inline_query')
|
||||
|
||||
def process_new_callback_query(self, new_callback_querys):
|
||||
self._notify_command_handlers(self.callback_query_handlers, new_callback_querys)
|
||||
self._notify_command_handlers(self.callback_query_handlers, new_callback_querys, 'callback_query')
|
||||
|
||||
def process_new_shipping_query(self, new_shipping_querys):
|
||||
self._notify_command_handlers(self.shipping_query_handlers, new_shipping_querys)
|
||||
self._notify_command_handlers(self.shipping_query_handlers, new_shipping_querys, 'shipping_query')
|
||||
|
||||
def process_new_pre_checkout_query(self, pre_checkout_querys):
|
||||
self._notify_command_handlers(self.pre_checkout_query_handlers, pre_checkout_querys)
|
||||
self._notify_command_handlers(self.pre_checkout_query_handlers, pre_checkout_querys, 'pre_checkout_query')
|
||||
|
||||
def process_new_poll(self, polls):
|
||||
self._notify_command_handlers(self.poll_handlers, polls)
|
||||
self._notify_command_handlers(self.poll_handlers, polls, 'poll')
|
||||
|
||||
def process_new_poll_answer(self, poll_answers):
|
||||
self._notify_command_handlers(self.poll_answer_handlers, poll_answers)
|
||||
self._notify_command_handlers(self.poll_answer_handlers, poll_answers, 'poll_answer')
|
||||
|
||||
def process_new_my_chat_member(self, my_chat_members):
|
||||
self._notify_command_handlers(self.my_chat_member_handlers, my_chat_members)
|
||||
self._notify_command_handlers(self.my_chat_member_handlers, my_chat_members, 'my_chat_member')
|
||||
|
||||
def process_new_chat_member(self, chat_members):
|
||||
self._notify_command_handlers(self.chat_member_handlers, chat_members)
|
||||
self._notify_command_handlers(self.chat_member_handlers, chat_members, 'chat_member')
|
||||
|
||||
def process_new_chat_join_request(self, chat_join_request):
|
||||
self._notify_command_handlers(self.chat_join_request_handlers, chat_join_request)
|
||||
self._notify_command_handlers(self.chat_join_request_handlers, chat_join_request, 'chat_join_request')
|
||||
|
||||
def process_middlewares(self, update):
|
||||
for update_type, middlewares in self.typed_middleware_handlers.items():
|
||||
@ -2535,6 +2548,20 @@ class TeleBot:
|
||||
chat_id = message.chat.id
|
||||
self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs)
|
||||
|
||||
|
||||
def setup_middleware(self, middleware: BaseMiddleware):
|
||||
"""
|
||||
Register middleware
|
||||
:param middleware: Subclass of `telebot.handler_backends.BaseMiddleware`
|
||||
:return: None
|
||||
"""
|
||||
if not self.use_class_middlewares:
|
||||
logger.warning('Middleware is not enabled. Pass use_class_middlewares=True to enable it.')
|
||||
return
|
||||
self.middlewares.append(middleware)
|
||||
|
||||
|
||||
|
||||
def set_state(self, user_id: int, state: Union[int, str], chat_id: int=None) -> None:
|
||||
"""
|
||||
Sets a new state of a user.
|
||||
@ -3500,16 +3527,95 @@ class TeleBot:
|
||||
logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.")
|
||||
return False
|
||||
|
||||
def _notify_command_handlers(self, handlers, new_messages):
|
||||
# middleware check-up method
|
||||
def _check_middleware(self, update_type):
|
||||
"""
|
||||
Check middleware
|
||||
:param message:
|
||||
:return:
|
||||
"""
|
||||
if self.middlewares: middlewares = [i for i in self.middlewares if update_type in i.update_types]
|
||||
if not middlewares: return
|
||||
return middlewares
|
||||
|
||||
def _run_middlewares_and_handler(self, message, handlers, middlewares, *args, **kwargs):
|
||||
"""This class is made to run handler and middleware in queue.
|
||||
:param handler: handler that should be executed.
|
||||
:param middleware: middleware that should be executed.
|
||||
:return:
|
||||
"""
|
||||
data = {}
|
||||
params =[]
|
||||
handler_error = None
|
||||
skip_handler = False
|
||||
if middlewares:
|
||||
for middleware in middlewares:
|
||||
result = middleware.pre_process(message, data)
|
||||
# We will break this loop if CancelUpdate is returned
|
||||
# Also, we will not run other middlewares
|
||||
if isinstance(result, CancelUpdate):
|
||||
return
|
||||
elif isinstance(result, SkipHandler) and skip_handler is False:
|
||||
skip_handler = True
|
||||
|
||||
|
||||
|
||||
|
||||
try:
|
||||
if handlers and not skip_handler:
|
||||
for handler in handlers:
|
||||
process_handler = self._test_message_handler(handler, message)
|
||||
if not process_handler: continue
|
||||
else:
|
||||
for i in inspect.signature(handler['function']).parameters:
|
||||
params.append(i)
|
||||
if len(params) == 1:
|
||||
handler['function'](message)
|
||||
|
||||
elif len(params) == 2:
|
||||
if handler.get('pass_bot') is True:
|
||||
handler['function'](message, self)
|
||||
|
||||
elif handler.get('pass_bot') is False:
|
||||
handler['function'](message, data)
|
||||
|
||||
elif len(params) == 3:
|
||||
if params[2] == 'bot' and handler.get('pass_bot') is True:
|
||||
handler['function'](message, data, self)
|
||||
|
||||
else:
|
||||
handler['function'](message, self, data)
|
||||
|
||||
except Exception as e:
|
||||
handler_error = e
|
||||
|
||||
if not middlewares:
|
||||
if self.exception_handler:
|
||||
return self.exception_handler.handle(e)
|
||||
logging.error(str(e))
|
||||
return
|
||||
if middlewares:
|
||||
for middleware in middlewares:
|
||||
middleware.post_process(message, data, handler_error)
|
||||
|
||||
|
||||
|
||||
|
||||
def _notify_command_handlers(self, handlers, new_messages, update_type):
|
||||
"""
|
||||
Notifies command handlers
|
||||
:param handlers:
|
||||
:param new_messages:
|
||||
:return:
|
||||
"""
|
||||
if len(handlers) == 0:
|
||||
if len(handlers) == 0 and not self.use_class_middlewares:
|
||||
return
|
||||
|
||||
for message in new_messages:
|
||||
middleware = self._check_middleware(update_type)
|
||||
if self.use_class_middlewares and middleware:
|
||||
self._exec_task(self._run_middlewares_and_handler, message, handlers=handlers, middlewares=middleware)
|
||||
return
|
||||
for message_handler in handlers:
|
||||
if self._test_message_handler(message_handler, message):
|
||||
self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler')
|
||||
|
@ -165,4 +165,43 @@ class StatesGroup:
|
||||
# change value of that variable
|
||||
value.name = ':'.join((cls.__name__, name))
|
||||
|
||||
|
||||
|
||||
class BaseMiddleware:
|
||||
"""
|
||||
Base class for middleware.
|
||||
Your middlewares should be inherited from this class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def pre_process(self, message, data):
|
||||
raise NotImplementedError
|
||||
|
||||
def post_process(self, message, data, exception):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SkipHandler:
|
||||
"""
|
||||
Class for skipping handlers.
|
||||
Just return instance of this class
|
||||
in middleware to skip handler.
|
||||
Update will go to post_process,
|
||||
but will skip execution of handler.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
class CancelUpdate:
|
||||
"""
|
||||
Class for canceling updates.
|
||||
Just return instance of this class
|
||||
in middleware to skip update.
|
||||
Update will skip handler and execution
|
||||
of post_process in middlewares.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
Loading…
Reference in New Issue
Block a user