1
0
mirror of https://github.com/eternnoir/pyTelegramBotAPI.git synced 2023-08-10 21:12:57 +03:00

Middleware support

This commit is contained in:
_run 2021-11-27 19:04:03 +05:00
parent d7b0513fb1
commit 6770011dd7
3 changed files with 138 additions and 307 deletions

View File

@ -13,6 +13,9 @@ from typing import Any, Callable, List, Optional, Union
import telebot.util
import telebot.types
from inspect import signature
logger = logging.getLogger('TeleBot')
formatter = logging.Formatter(
@ -69,6 +72,30 @@ class ExceptionHandler:
return False
class SkipHandler:
"""
Class for skipping handlers.
Just return instance of this class
in middleware to skip handler.
Update will go to post_process,
but will skip execution of handler.
"""
def __init__(self) -> None:
pass
class CancelUpdate:
"""
Class for canceling updates.
Just return instance of this class
in middleware to skip update.
Update will skip handler and execution
of post_process in middlewares.
"""
def __init__(self) -> None:
pass
class TeleBot:
""" This is TeleBot Class
Methods:
@ -3351,33 +3378,16 @@ class AsyncTeleBot:
self.current_states = asyncio_handler_backends.StateMemory()
if asyncio_helper.ENABLE_MIDDLEWARE:
self.typed_middleware_handlers = {
'message': [],
'edited_message': [],
'channel_post': [],
'edited_channel_post': [],
'inline_query': [],
'chosen_inline_result': [],
'callback_query': [],
'shipping_query': [],
'pre_checkout_query': [],
'poll': [],
'poll_answer': [],
'my_chat_member': [],
'chat_member': [],
'chat_join_request': []
}
self.default_middleware_handlers = []
self.middlewares = []
async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None,
timeout: Optional[int]=None, allowed_updates: Optional[List]=None) -> types.Update:
json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates)
timeout: Optional[int]=None, allowed_updates: Optional[List]=None, request_timeout: Optional[int]=None) -> types.Update:
json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates, request_timeout)
return [types.Update.de_json(ju) for ju in json_updates]
async def polling(self, non_stop: bool=False, skip_pending=False, interval: int=0, timeout: int=20,
long_polling_timeout: int=20, allowed_updates: Optional[List[str]]=None,
request_timeout: int=20, allowed_updates: Optional[List[str]]=None,
none_stop: Optional[bool]=None):
"""
This allows the bot to retrieve Updates automatically and notify listeners and message handlers accordingly.
@ -3389,7 +3399,7 @@ class AsyncTeleBot:
:param non_stop: Do not stop polling when an ApiException occurs.
:param timeout: Request connection timeout
:param skip_pending: skip old updates
:param long_polling_timeout: Timeout in seconds for long polling (see API docs)
:param request_timeout: Timeout in seconds for a request.
:param allowed_updates: A list of the update types you want your bot to receive.
For example, specify [message, edited_channel_post, callback_query] to only receive updates of these types.
See util.update_types for a complete list of available update types.
@ -3406,9 +3416,9 @@ class AsyncTeleBot:
if skip_pending:
await self.skip_updates()
await self._process_polling(non_stop, interval, timeout, long_polling_timeout, allowed_updates)
await self._process_polling(non_stop, interval, timeout, request_timeout, allowed_updates)
async def infinity_polling(self, timeout: int=20, skip_pending: bool=False, long_polling_timeout: int=20, logger_level=logging.ERROR,
async def infinity_polling(self, timeout: int=20, skip_pending: bool=False, request_timeout: int=20, logger_level=logging.ERROR,
allowed_updates: Optional[List[str]]=None, *args, **kwargs):
"""
Wrap polling with infinite loop and exception handling to avoid bot stops polling.
@ -3432,7 +3442,7 @@ class AsyncTeleBot:
self._polling = True
while self._polling:
try:
await self._process_polling(non_stop=True, timeout=timeout, long_polling_timeout=long_polling_timeout,
await self._process_polling(non_stop=True, timeout=timeout, request_timeout=request_timeout,
allowed_updates=allowed_updates, *args, **kwargs)
except Exception as e:
if logger_level and logger_level >= logging.ERROR:
@ -3447,13 +3457,13 @@ class AsyncTeleBot:
logger.error("Break infinity polling")
async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: int=20,
long_polling_timeout: int=20, allowed_updates: Optional[List[str]]=None):
request_timeout: int=20, allowed_updates: Optional[List[str]]=None):
"""
Function to process polling.
:param non_stop: Do not stop polling when an ApiException occurs.
:param interval: Delay between two update retrivals
:param timeout: Request connection timeout
:param long_polling_timeout: Timeout in seconds for long polling (see API docs)
:param request_timeout: Timeout in seconds for long polling (see API docs)
:param allowed_updates: A list of the update types you want your bot to receive.
For example, specify [message, edited_channel_post, callback_query] to only receive updates of these types.
See util.update_types for a complete list of available update types.
@ -3471,49 +3481,74 @@ class AsyncTeleBot:
while self._polling:
try:
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout)
if updates:
logger.debug(f"Received {len(updates)} updates.")
await self.process_new_updates(updates)
if interval: await asyncio.sleep(interval)
except KeyboardInterrupt:
logger.info("KeyboardInterrupt received.")
break
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout)
except asyncio.CancelledError:
break
return
except asyncio_helper.ApiTelegramException as e:
logger.info(str(e))
logger.error(str(e))
continue
except Exception as e:
logger.error('Cause exception while getting updates.')
logger.error(str(e))
await asyncio.sleep(3)
continue
if non_stop:
logger.error(str(e))
await asyncio.sleep(3)
continue
else:
raise e
if updates:
self.offset = updates[-1].update_id + 1
self._loop_create_task(self.process_new_updates(updates)) # Seperate task for processing updates
if interval: await asyncio.sleep(interval)
finally:
self._polling = False
logger.warning('Polling is stopped.')
async def _loop_create_task(self, coro):
def _loop_create_task(self, coro):
return asyncio.create_task(coro)
async def _process_updates(self, handlers, messages):
async def _process_updates(self, handlers, messages, update_type):
"""
Process updates.
:param handlers:
:param messages:
:return:
"""
for message in messages:
for message_handler in handlers:
process_update = await self._test_message_handler(message_handler, message)
if not process_update:
continue
elif process_update:
try:
await self._loop_create_task(message_handler['function'](message))
break
except Exception as e:
logger.error(str(e))
middleware = await self.process_middlewares(message, update_type)
self._loop_create_task(self._run_middlewares_and_handlers(handlers, message, middleware))
async def _run_middlewares_and_handlers(self, handlers, message, middleware):
handler_error = None
data = {}
for message_handler in handlers:
process_update = await self._test_message_handler(message_handler, message)
if not process_update:
continue
elif process_update:
if middleware:
middleware_result = await middleware.pre_process(message, data)
if isinstance(middleware_result, SkipHandler):
await middleware.post_process(message, data, handler_error)
break
if isinstance(middleware_result, CancelUpdate):
return
try:
if "data" in signature(message_handler['function']).parameters:
await message_handler['function'](message, data)
else:
await message_handler['function'](message)
break
except Exception as e:
handler_error = e
logger.info(str(e))
if middleware:
await middleware.post_process(message, data, handler_error)
# update handling
async def process_new_updates(self, updates):
upd_count = len(updates)
@ -3535,21 +3570,8 @@ class AsyncTeleBot:
new_chat_members = None
chat_join_request = None
for update in updates:
if asyncio_helper.ENABLE_MIDDLEWARE:
try:
self.process_middlewares(update)
except Exception as e:
logger.error(str(e))
if not self.suppress_middleware_excepions:
raise
else:
if update.update_id > self.offset: self.offset = update.update_id
continue
logger.debug('Processing updates: {0}'.format(update))
if update.update_id:
self.offset = update.update_id + 1
if update.message:
logger.info('Processing message')
if new_messages is None: new_messages = []
new_messages.append(update.message)
if update.edited_message:
@ -3620,69 +3642,55 @@ class AsyncTeleBot:
await self.process_new_chat_member(new_chat_members)
if chat_join_request:
await self.process_chat_join_request(chat_join_request)
async def process_new_messages(self, new_messages):
await self.__notify_update(new_messages)
await self._process_updates(self.message_handlers, new_messages)
await self._process_updates(self.message_handlers, new_messages, 'message')
async def process_new_edited_messages(self, edited_message):
await self._process_updates(self.edited_message_handlers, edited_message)
await self._process_updates(self.edited_message_handlers, edited_message, 'edited_message')
async def process_new_channel_posts(self, channel_post):
await self._process_updates(self.channel_post_handlers, channel_post)
await self._process_updates(self.channel_post_handlers, channel_post , 'channel_post')
async def process_new_edited_channel_posts(self, edited_channel_post):
await self._process_updates(self.edited_channel_post_handlers, edited_channel_post)
await self._process_updates(self.edited_channel_post_handlers, edited_channel_post, 'edited_channel_post')
async def process_new_inline_query(self, new_inline_querys):
await self._process_updates(self.inline_handlers, new_inline_querys)
await self._process_updates(self.inline_handlers, new_inline_querys, 'inline_query')
async def process_new_chosen_inline_query(self, new_chosen_inline_querys):
await self._process_updates(self.chosen_inline_handlers, new_chosen_inline_querys)
await self._process_updates(self.chosen_inline_handlers, new_chosen_inline_querys, 'chosen_inline_query')
async def process_new_callback_query(self, new_callback_querys):
await self._process_updates(self.callback_query_handlers, new_callback_querys)
await self._process_updates(self.callback_query_handlers, new_callback_querys, 'callback_query')
async def process_new_shipping_query(self, new_shipping_querys):
await self._process_updates(self.shipping_query_handlers, new_shipping_querys)
await self._process_updates(self.shipping_query_handlers, new_shipping_querys, 'shipping_query')
async def process_new_pre_checkout_query(self, pre_checkout_querys):
await self._process_updates(self.pre_checkout_query_handlers, pre_checkout_querys)
await self._process_updates(self.pre_checkout_query_handlers, pre_checkout_querys, 'pre_checkout_query')
async def process_new_poll(self, polls):
await self._process_updates(self.poll_handlers, polls)
await self._process_updates(self.poll_handlers, polls, 'poll')
async def process_new_poll_answer(self, poll_answers):
await self._process_updates(self.poll_answer_handlers, poll_answers)
await self._process_updates(self.poll_answer_handlers, poll_answers, 'poll_answer')
async def process_new_my_chat_member(self, my_chat_members):
await self._process_updates(self.my_chat_member_handlers, my_chat_members)
await self._process_updates(self.my_chat_member_handlers, my_chat_members, 'my_chat_member')
async def process_new_chat_member(self, chat_members):
await self._process_updates(self.chat_member_handlers, chat_members)
await self._process_updates(self.chat_member_handlers, chat_members, 'chat_member')
async def process_chat_join_request(self, chat_join_request):
await self._process_updates(self.chat_join_request_handlers, chat_join_request)
async def process_middlewares(self, update):
for update_type, middlewares in self.typed_middleware_handlers.items():
if getattr(update, update_type) is not None:
for typed_middleware_handler in middlewares:
try:
typed_middleware_handler(self, getattr(update, update_type))
except Exception as e:
e.args = e.args + (f'Typed middleware handler "{typed_middleware_handler.__qualname__}"',)
raise
if len(self.default_middleware_handlers) > 0:
for default_middleware_handler in self.default_middleware_handlers:
try:
default_middleware_handler(self, update)
except Exception as e:
e.args = e.args + (f'Default middleware handler "{default_middleware_handler.__qualname__}"',)
raise
await self._process_updates(self.chat_join_request_handlers, chat_join_request, 'chat_join_request')
async def process_middlewares(self, update, update_type):
for middleware in self.middlewares:
if update_type in middleware.update_types:
return middleware
return None
async def __notify_update(self, new_messages):
if len(self.update_listener) == 0:
@ -3759,56 +3767,16 @@ class AsyncTeleBot:
elif isinstance(filter_check, asyncio_filters.AdvancedCustomFilter):
return await filter_check.check(message, filter_value)
else:
logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter!")
logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.")
return False
def middleware_handler(self, update_types=None):
def setup_middleware(self, middleware):
"""
Middleware handler decorator.
This decorator can be used to decorate functions that must be handled as middlewares before entering any other
message handlers
But, be careful and check type of the update inside the handler if more than one update_type is given
Example:
bot = TeleBot('TOKEN')
# Print post message text before entering to any post_channel handlers
@bot.middleware_handler(update_types=['channel_post', 'edited_channel_post'])
def print_channel_post_text(bot_instance, channel_post):
print(channel_post.text)
# Print update id before entering to any handlers
@bot.middleware_handler()
def print_channel_post_text(bot_instance, update):
print(update.update_id)
:param update_types: Optional list of update types that can be passed into the middleware handler.
"""
def decorator(handler):
self.add_middleware_handler(handler, update_types)
return handler
return decorator
def add_middleware_handler(self, handler, update_types=None):
"""
Add middleware handler
:param handler:
:param update_types:
Setup middleware
:param middleware:
:return:
"""
if not asyncio_helper.ENABLE_MIDDLEWARE:
raise RuntimeError("Middleware is not enabled. Use asyncio_helper.ENABLE_MIDDLEWARE.")
if update_types:
for update_type in update_types:
self.typed_middleware_handlers[update_type].append(handler)
else:
self.default_middleware_handlers.append(handler)
self.middlewares.append(middleware)
def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs):
"""

View File

@ -1,146 +1,6 @@
import os
import pickle
import threading
from telebot import apihelper
class HandlerBackend(object):
"""
Class for saving (next step|reply) handlers
"""
def __init__(self, handlers=None):
if handlers is None:
handlers = {}
self.handlers = handlers
async def register_handler(self, handler_group_id, handler):
raise NotImplementedError()
async def clear_handlers(self, handler_group_id):
raise NotImplementedError()
async def get_handlers(self, handler_group_id):
raise NotImplementedError()
class MemoryHandlerBackend(HandlerBackend):
async def register_handler(self, handler_group_id, handler):
if handler_group_id in self.handlers:
self.handlers[handler_group_id].append(handler)
else:
self.handlers[handler_group_id] = [handler]
async def clear_handlers(self, handler_group_id):
self.handlers.pop(handler_group_id, None)
async def get_handlers(self, handler_group_id):
return self.handlers.pop(handler_group_id, None)
async def load_handlers(self, filename, del_file_after_loading):
raise NotImplementedError()
class FileHandlerBackend(HandlerBackend):
def __init__(self, handlers=None, filename='./.handler-saves/handlers.save', delay=120):
super(FileHandlerBackend, self).__init__(handlers)
self.filename = filename
self.delay = delay
self.timer = threading.Timer(delay, self.save_handlers)
async def register_handler(self, handler_group_id, handler):
if handler_group_id in self.handlers:
self.handlers[handler_group_id].append(handler)
else:
self.handlers[handler_group_id] = [handler]
await self.start_save_timer()
async def clear_handlers(self, handler_group_id):
self.handlers.pop(handler_group_id, None)
await self.start_save_timer()
async def get_handlers(self, handler_group_id):
handlers = self.handlers.pop(handler_group_id, None)
await self.start_save_timer()
return handlers
async def start_save_timer(self):
if not self.timer.is_alive():
if self.delay <= 0:
self.save_handlers()
else:
self.timer = threading.Timer(self.delay, self.save_handlers)
self.timer.start()
async def save_handlers(self):
await self.dump_handlers(self.handlers, self.filename)
async def load_handlers(self, filename=None, del_file_after_loading=True):
if not filename:
filename = self.filename
tmp = await self.return_load_handlers(filename, del_file_after_loading=del_file_after_loading)
if tmp is not None:
self.handlers.update(tmp)
@staticmethod
async def dump_handlers(handlers, filename, file_mode="wb"):
dirs = filename.rsplit('/', maxsplit=1)[0]
os.makedirs(dirs, exist_ok=True)
with open(filename + ".tmp", file_mode) as file:
if (apihelper.CUSTOM_SERIALIZER is None):
pickle.dump(handlers, file)
else:
apihelper.CUSTOM_SERIALIZER.dump(handlers, file)
if os.path.isfile(filename):
os.remove(filename)
os.rename(filename + ".tmp", filename)
@staticmethod
async def return_load_handlers(filename, del_file_after_loading=True):
if os.path.isfile(filename) and os.path.getsize(filename) > 0:
with open(filename, "rb") as file:
if (apihelper.CUSTOM_SERIALIZER is None):
handlers = pickle.load(file)
else:
handlers = apihelper.CUSTOM_SERIALIZER.load(file)
if del_file_after_loading:
os.remove(filename)
return handlers
class RedisHandlerBackend(HandlerBackend):
def __init__(self, handlers=None, host='localhost', port=6379, db=0, prefix='telebot', password=None):
super(RedisHandlerBackend, self).__init__(handlers)
from redis import Redis
self.prefix = prefix
self.redis = Redis(host, port, db, password)
async def _key(self, handle_group_id):
return ':'.join((self.prefix, str(handle_group_id)))
async def register_handler(self, handler_group_id, handler):
handlers = []
value = self.redis.get(self._key(handler_group_id))
if value:
handlers = pickle.loads(value)
handlers.append(handler)
self.redis.set(self._key(handler_group_id), pickle.dumps(handlers))
async def clear_handlers(self, handler_group_id):
self.redis.delete(self._key(handler_group_id))
async def get_handlers(self, handler_group_id):
handlers = None
value = self.redis.get(self._key(handler_group_id))
if value:
handlers = pickle.loads(value)
self.clear_handlers(handler_group_id)
return handlers
class StateMemory:
@ -341,3 +201,19 @@ class StateFileContext:
await self.obj._save_data(old_data)
return
class BaseMiddleware:
"""
Base class for middleware.
Your middlewares should be inherited from this class.
"""
def __init__(self):
pass
async def pre_process(self, message, data):
raise NotImplementedError
async def post_process(self, message, data, exception):
raise NotImplementedError

View File

@ -1,28 +1,19 @@
import asyncio
import asyncio # for future uses
from time import time
import aiohttp
from telebot import types
import json
import logging
try:
import ujson as json
except ImportError:
import json
import requests
from requests.exceptions import HTTPError, ConnectionError, Timeout
try:
# noinspection PyUnresolvedReferences
from requests.packages.urllib3 import fields
format_header_param = fields.format_header_param
except ImportError:
format_header_param = None
API_URL = 'https://api.telegram.org/bot{0}/{1}'
from datetime import datetime
import telebot
from telebot import util
class SessionBase:
@ -44,19 +35,16 @@ READ_TIMEOUT = 30
LONG_POLLING_TIMEOUT = 10 # Should be positive, short polling should be used for testing purposes only (https://core.telegram.org/bots/api#getupdates)
logger = telebot.logger
RETRY_ON_ERROR = False
RETRY_TIMEOUT = 2
MAX_RETRIES = 15
CUSTOM_SERIALIZER = None
CUSTOM_REQUEST_SENDER = None
ENABLE_MIDDLEWARE = False
async def _process_request(token, url, method='get', params=None, files=None):
async def _process_request(token, url, method='get', params=None, files=None, request_timeout=None):
async with await session_manager._get_new_session() as session:
async with session.get(API_URL.format(token, url), params=params, data=files) as response:
async with session.get(API_URL.format(token, url), params=params, data=files, timeout=request_timeout) as response:
logger.debug("Request: method={0} url={1} params={2} files={3} request_timeout={4}".format(method, url, params, files, request_timeout).replace(token, token.split(':')[0] + ":{TOKEN}"))
json_result = await _check_result(url, response)
if json_result:
return json_result['result']
@ -155,7 +143,7 @@ async def get_webhook_info(token, timeout=None):
async def get_updates(token, offset=None, limit=None,
timeout=None, allowed_updates=None, long_polling_timeout=None):
timeout=None, allowed_updates=None, request_timeout=None):
method_name = 'getUpdates'
params = {}
if offset:
@ -166,8 +154,7 @@ async def get_updates(token, offset=None, limit=None,
params['timeout'] = timeout
elif allowed_updates:
params['allowed_updates'] = allowed_updates
params['long_polling_timeout'] = long_polling_timeout if long_polling_timeout else LONG_POLLING_TIMEOUT
return await _process_request(token, method_name, params=params)
return await _process_request(token, method_name, params=params, request_timeout=request_timeout)
async def _check_result(method_name, result):
"""