mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
Middleware support
This commit is contained in:
parent
d7b0513fb1
commit
6770011dd7
@ -13,6 +13,9 @@ from typing import Any, Callable, List, Optional, Union
|
|||||||
import telebot.util
|
import telebot.util
|
||||||
import telebot.types
|
import telebot.types
|
||||||
|
|
||||||
|
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
logger = logging.getLogger('TeleBot')
|
logger = logging.getLogger('TeleBot')
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
@ -69,6 +72,30 @@ class ExceptionHandler:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class SkipHandler:
|
||||||
|
"""
|
||||||
|
Class for skipping handlers.
|
||||||
|
Just return instance of this class
|
||||||
|
in middleware to skip handler.
|
||||||
|
Update will go to post_process,
|
||||||
|
but will skip execution of handler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class CancelUpdate:
|
||||||
|
"""
|
||||||
|
Class for canceling updates.
|
||||||
|
Just return instance of this class
|
||||||
|
in middleware to skip update.
|
||||||
|
Update will skip handler and execution
|
||||||
|
of post_process in middlewares.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
class TeleBot:
|
class TeleBot:
|
||||||
""" This is TeleBot Class
|
""" This is TeleBot Class
|
||||||
Methods:
|
Methods:
|
||||||
@ -3351,33 +3378,16 @@ class AsyncTeleBot:
|
|||||||
self.current_states = asyncio_handler_backends.StateMemory()
|
self.current_states = asyncio_handler_backends.StateMemory()
|
||||||
|
|
||||||
|
|
||||||
if asyncio_helper.ENABLE_MIDDLEWARE:
|
self.middlewares = []
|
||||||
self.typed_middleware_handlers = {
|
|
||||||
'message': [],
|
|
||||||
'edited_message': [],
|
|
||||||
'channel_post': [],
|
|
||||||
'edited_channel_post': [],
|
|
||||||
'inline_query': [],
|
|
||||||
'chosen_inline_result': [],
|
|
||||||
'callback_query': [],
|
|
||||||
'shipping_query': [],
|
|
||||||
'pre_checkout_query': [],
|
|
||||||
'poll': [],
|
|
||||||
'poll_answer': [],
|
|
||||||
'my_chat_member': [],
|
|
||||||
'chat_member': [],
|
|
||||||
'chat_join_request': []
|
|
||||||
}
|
|
||||||
self.default_middleware_handlers = []
|
|
||||||
|
|
||||||
|
|
||||||
async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None,
|
async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None,
|
||||||
timeout: Optional[int]=None, allowed_updates: Optional[List]=None) -> types.Update:
|
timeout: Optional[int]=None, allowed_updates: Optional[List]=None, request_timeout: Optional[int]=None) -> types.Update:
|
||||||
json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates)
|
json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates, request_timeout)
|
||||||
return [types.Update.de_json(ju) for ju in json_updates]
|
return [types.Update.de_json(ju) for ju in json_updates]
|
||||||
|
|
||||||
async def polling(self, non_stop: bool=False, skip_pending=False, interval: int=0, timeout: int=20,
|
async def polling(self, non_stop: bool=False, skip_pending=False, interval: int=0, timeout: int=20,
|
||||||
long_polling_timeout: int=20, allowed_updates: Optional[List[str]]=None,
|
request_timeout: int=20, allowed_updates: Optional[List[str]]=None,
|
||||||
none_stop: Optional[bool]=None):
|
none_stop: Optional[bool]=None):
|
||||||
"""
|
"""
|
||||||
This allows the bot to retrieve Updates automatically and notify listeners and message handlers accordingly.
|
This allows the bot to retrieve Updates automatically and notify listeners and message handlers accordingly.
|
||||||
@ -3389,7 +3399,7 @@ class AsyncTeleBot:
|
|||||||
:param non_stop: Do not stop polling when an ApiException occurs.
|
:param non_stop: Do not stop polling when an ApiException occurs.
|
||||||
:param timeout: Request connection timeout
|
:param timeout: Request connection timeout
|
||||||
:param skip_pending: skip old updates
|
:param skip_pending: skip old updates
|
||||||
:param long_polling_timeout: Timeout in seconds for long polling (see API docs)
|
:param request_timeout: Timeout in seconds for a request.
|
||||||
:param allowed_updates: A list of the update types you want your bot to receive.
|
:param allowed_updates: A list of the update types you want your bot to receive.
|
||||||
For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types.
|
For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types.
|
||||||
See util.update_types for a complete list of available update types.
|
See util.update_types for a complete list of available update types.
|
||||||
@ -3406,9 +3416,9 @@ class AsyncTeleBot:
|
|||||||
|
|
||||||
if skip_pending:
|
if skip_pending:
|
||||||
await self.skip_updates()
|
await self.skip_updates()
|
||||||
await self._process_polling(non_stop, interval, timeout, long_polling_timeout, allowed_updates)
|
await self._process_polling(non_stop, interval, timeout, request_timeout, allowed_updates)
|
||||||
|
|
||||||
async def infinity_polling(self, timeout: int=20, skip_pending: bool=False, long_polling_timeout: int=20, logger_level=logging.ERROR,
|
async def infinity_polling(self, timeout: int=20, skip_pending: bool=False, request_timeout: int=20, logger_level=logging.ERROR,
|
||||||
allowed_updates: Optional[List[str]]=None, *args, **kwargs):
|
allowed_updates: Optional[List[str]]=None, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Wrap polling with infinite loop and exception handling to avoid bot stops polling.
|
Wrap polling with infinite loop and exception handling to avoid bot stops polling.
|
||||||
@ -3432,7 +3442,7 @@ class AsyncTeleBot:
|
|||||||
self._polling = True
|
self._polling = True
|
||||||
while self._polling:
|
while self._polling:
|
||||||
try:
|
try:
|
||||||
await self._process_polling(non_stop=True, timeout=timeout, long_polling_timeout=long_polling_timeout,
|
await self._process_polling(non_stop=True, timeout=timeout, request_timeout=request_timeout,
|
||||||
allowed_updates=allowed_updates, *args, **kwargs)
|
allowed_updates=allowed_updates, *args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if logger_level and logger_level >= logging.ERROR:
|
if logger_level and logger_level >= logging.ERROR:
|
||||||
@ -3447,13 +3457,13 @@ class AsyncTeleBot:
|
|||||||
logger.error("Break infinity polling")
|
logger.error("Break infinity polling")
|
||||||
|
|
||||||
async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: int=20,
|
async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: int=20,
|
||||||
long_polling_timeout: int=20, allowed_updates: Optional[List[str]]=None):
|
request_timeout: int=20, allowed_updates: Optional[List[str]]=None):
|
||||||
"""
|
"""
|
||||||
Function to process polling.
|
Function to process polling.
|
||||||
:param non_stop: Do not stop polling when an ApiException occurs.
|
:param non_stop: Do not stop polling when an ApiException occurs.
|
||||||
:param interval: Delay between two update retrivals
|
:param interval: Delay between two update retrivals
|
||||||
:param timeout: Request connection timeout
|
:param timeout: Request connection timeout
|
||||||
:param long_polling_timeout: Timeout in seconds for long polling (see API docs)
|
:param request_timeout: Timeout in seconds for long polling (see API docs)
|
||||||
:param allowed_updates: A list of the update types you want your bot to receive.
|
:param allowed_updates: A list of the update types you want your bot to receive.
|
||||||
For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types.
|
For example, specify [“message”, “edited_channel_post”, “callback_query”] to only receive updates of these types.
|
||||||
See util.update_types for a complete list of available update types.
|
See util.update_types for a complete list of available update types.
|
||||||
@ -3471,49 +3481,74 @@ class AsyncTeleBot:
|
|||||||
while self._polling:
|
while self._polling:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout)
|
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout)
|
||||||
|
|
||||||
if updates:
|
|
||||||
logger.debug(f"Received {len(updates)} updates.")
|
|
||||||
|
|
||||||
await self.process_new_updates(updates)
|
|
||||||
if interval: await asyncio.sleep(interval)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("KeyboardInterrupt received.")
|
|
||||||
break
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
return
|
||||||
|
|
||||||
except asyncio_helper.ApiTelegramException as e:
|
except asyncio_helper.ApiTelegramException as e:
|
||||||
logger.info(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('Cause exception while getting updates.')
|
logger.error('Cause exception while getting updates.')
|
||||||
logger.error(str(e))
|
if non_stop:
|
||||||
await asyncio.sleep(3)
|
logger.error(str(e))
|
||||||
continue
|
await asyncio.sleep(3)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
if updates:
|
||||||
|
self.offset = updates[-1].update_id + 1
|
||||||
|
self._loop_create_task(self.process_new_updates(updates)) # Seperate task for processing updates
|
||||||
|
if interval: await asyncio.sleep(interval)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._polling = False
|
self._polling = False
|
||||||
logger.warning('Polling is stopped.')
|
logger.warning('Polling is stopped.')
|
||||||
|
|
||||||
|
|
||||||
async def _loop_create_task(self, coro):
|
def _loop_create_task(self, coro):
|
||||||
return asyncio.create_task(coro)
|
return asyncio.create_task(coro)
|
||||||
|
|
||||||
async def _process_updates(self, handlers, messages):
|
async def _process_updates(self, handlers, messages, update_type):
|
||||||
|
"""
|
||||||
|
Process updates.
|
||||||
|
:param handlers:
|
||||||
|
:param messages:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
for message in messages:
|
for message in messages:
|
||||||
for message_handler in handlers:
|
middleware = await self.process_middlewares(message, update_type)
|
||||||
process_update = await self._test_message_handler(message_handler, message)
|
self._loop_create_task(self._run_middlewares_and_handlers(handlers, message, middleware))
|
||||||
if not process_update:
|
|
||||||
continue
|
|
||||||
elif process_update:
|
|
||||||
try:
|
|
||||||
await self._loop_create_task(message_handler['function'](message))
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_middlewares_and_handlers(self, handlers, message, middleware):
|
||||||
|
handler_error = None
|
||||||
|
data = {}
|
||||||
|
for message_handler in handlers:
|
||||||
|
process_update = await self._test_message_handler(message_handler, message)
|
||||||
|
if not process_update:
|
||||||
|
continue
|
||||||
|
elif process_update:
|
||||||
|
if middleware:
|
||||||
|
middleware_result = await middleware.pre_process(message, data)
|
||||||
|
if isinstance(middleware_result, SkipHandler):
|
||||||
|
await middleware.post_process(message, data, handler_error)
|
||||||
|
break
|
||||||
|
if isinstance(middleware_result, CancelUpdate):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if "data" in signature(message_handler['function']).parameters:
|
||||||
|
await message_handler['function'](message, data)
|
||||||
|
else:
|
||||||
|
await message_handler['function'](message)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
handler_error = e
|
||||||
|
logger.info(str(e))
|
||||||
|
|
||||||
|
if middleware:
|
||||||
|
await middleware.post_process(message, data, handler_error)
|
||||||
# update handling
|
# update handling
|
||||||
async def process_new_updates(self, updates):
|
async def process_new_updates(self, updates):
|
||||||
upd_count = len(updates)
|
upd_count = len(updates)
|
||||||
@ -3535,21 +3570,8 @@ class AsyncTeleBot:
|
|||||||
new_chat_members = None
|
new_chat_members = None
|
||||||
chat_join_request = None
|
chat_join_request = None
|
||||||
for update in updates:
|
for update in updates:
|
||||||
if asyncio_helper.ENABLE_MIDDLEWARE:
|
|
||||||
try:
|
|
||||||
self.process_middlewares(update)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
if not self.suppress_middleware_excepions:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
if update.update_id > self.offset: self.offset = update.update_id
|
|
||||||
continue
|
|
||||||
logger.debug('Processing updates: {0}'.format(update))
|
logger.debug('Processing updates: {0}'.format(update))
|
||||||
if update.update_id:
|
|
||||||
self.offset = update.update_id + 1
|
|
||||||
if update.message:
|
if update.message:
|
||||||
logger.info('Processing message')
|
|
||||||
if new_messages is None: new_messages = []
|
if new_messages is None: new_messages = []
|
||||||
new_messages.append(update.message)
|
new_messages.append(update.message)
|
||||||
if update.edited_message:
|
if update.edited_message:
|
||||||
@ -3621,68 +3643,54 @@ class AsyncTeleBot:
|
|||||||
if chat_join_request:
|
if chat_join_request:
|
||||||
await self.process_chat_join_request(chat_join_request)
|
await self.process_chat_join_request(chat_join_request)
|
||||||
|
|
||||||
|
|
||||||
async def process_new_messages(self, new_messages):
|
async def process_new_messages(self, new_messages):
|
||||||
await self.__notify_update(new_messages)
|
await self.__notify_update(new_messages)
|
||||||
await self._process_updates(self.message_handlers, new_messages)
|
await self._process_updates(self.message_handlers, new_messages, 'message')
|
||||||
|
|
||||||
async def process_new_edited_messages(self, edited_message):
|
async def process_new_edited_messages(self, edited_message):
|
||||||
await self._process_updates(self.edited_message_handlers, edited_message)
|
await self._process_updates(self.edited_message_handlers, edited_message, 'edited_message')
|
||||||
|
|
||||||
async def process_new_channel_posts(self, channel_post):
|
async def process_new_channel_posts(self, channel_post):
|
||||||
await self._process_updates(self.channel_post_handlers, channel_post)
|
await self._process_updates(self.channel_post_handlers, channel_post , 'channel_post')
|
||||||
|
|
||||||
async def process_new_edited_channel_posts(self, edited_channel_post):
|
async def process_new_edited_channel_posts(self, edited_channel_post):
|
||||||
await self._process_updates(self.edited_channel_post_handlers, edited_channel_post)
|
await self._process_updates(self.edited_channel_post_handlers, edited_channel_post, 'edited_channel_post')
|
||||||
|
|
||||||
async def process_new_inline_query(self, new_inline_querys):
|
async def process_new_inline_query(self, new_inline_querys):
|
||||||
await self._process_updates(self.inline_handlers, new_inline_querys)
|
await self._process_updates(self.inline_handlers, new_inline_querys, 'inline_query')
|
||||||
|
|
||||||
async def process_new_chosen_inline_query(self, new_chosen_inline_querys):
|
async def process_new_chosen_inline_query(self, new_chosen_inline_querys):
|
||||||
await self._process_updates(self.chosen_inline_handlers, new_chosen_inline_querys)
|
await self._process_updates(self.chosen_inline_handlers, new_chosen_inline_querys, 'chosen_inline_query')
|
||||||
|
|
||||||
async def process_new_callback_query(self, new_callback_querys):
|
async def process_new_callback_query(self, new_callback_querys):
|
||||||
await self._process_updates(self.callback_query_handlers, new_callback_querys)
|
await self._process_updates(self.callback_query_handlers, new_callback_querys, 'callback_query')
|
||||||
|
|
||||||
async def process_new_shipping_query(self, new_shipping_querys):
|
async def process_new_shipping_query(self, new_shipping_querys):
|
||||||
await self._process_updates(self.shipping_query_handlers, new_shipping_querys)
|
await self._process_updates(self.shipping_query_handlers, new_shipping_querys, 'shipping_query')
|
||||||
|
|
||||||
async def process_new_pre_checkout_query(self, pre_checkout_querys):
|
async def process_new_pre_checkout_query(self, pre_checkout_querys):
|
||||||
await self._process_updates(self.pre_checkout_query_handlers, pre_checkout_querys)
|
await self._process_updates(self.pre_checkout_query_handlers, pre_checkout_querys, 'pre_checkout_query')
|
||||||
|
|
||||||
async def process_new_poll(self, polls):
|
async def process_new_poll(self, polls):
|
||||||
await self._process_updates(self.poll_handlers, polls)
|
await self._process_updates(self.poll_handlers, polls, 'poll')
|
||||||
|
|
||||||
async def process_new_poll_answer(self, poll_answers):
|
async def process_new_poll_answer(self, poll_answers):
|
||||||
await self._process_updates(self.poll_answer_handlers, poll_answers)
|
await self._process_updates(self.poll_answer_handlers, poll_answers, 'poll_answer')
|
||||||
|
|
||||||
async def process_new_my_chat_member(self, my_chat_members):
|
async def process_new_my_chat_member(self, my_chat_members):
|
||||||
await self._process_updates(self.my_chat_member_handlers, my_chat_members)
|
await self._process_updates(self.my_chat_member_handlers, my_chat_members, 'my_chat_member')
|
||||||
|
|
||||||
async def process_new_chat_member(self, chat_members):
|
async def process_new_chat_member(self, chat_members):
|
||||||
await self._process_updates(self.chat_member_handlers, chat_members)
|
await self._process_updates(self.chat_member_handlers, chat_members, 'chat_member')
|
||||||
|
|
||||||
async def process_chat_join_request(self, chat_join_request):
|
async def process_chat_join_request(self, chat_join_request):
|
||||||
await self._process_updates(self.chat_join_request_handlers, chat_join_request)
|
await self._process_updates(self.chat_join_request_handlers, chat_join_request, 'chat_join_request')
|
||||||
|
|
||||||
async def process_middlewares(self, update):
|
|
||||||
for update_type, middlewares in self.typed_middleware_handlers.items():
|
|
||||||
if getattr(update, update_type) is not None:
|
|
||||||
for typed_middleware_handler in middlewares:
|
|
||||||
try:
|
|
||||||
typed_middleware_handler(self, getattr(update, update_type))
|
|
||||||
except Exception as e:
|
|
||||||
e.args = e.args + (f'Typed middleware handler "{typed_middleware_handler.__qualname__}"',)
|
|
||||||
raise
|
|
||||||
|
|
||||||
if len(self.default_middleware_handlers) > 0:
|
|
||||||
for default_middleware_handler in self.default_middleware_handlers:
|
|
||||||
try:
|
|
||||||
default_middleware_handler(self, update)
|
|
||||||
except Exception as e:
|
|
||||||
e.args = e.args + (f'Default middleware handler "{default_middleware_handler.__qualname__}"',)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
async def process_middlewares(self, update, update_type):
|
||||||
|
for middleware in self.middlewares:
|
||||||
|
if update_type in middleware.update_types:
|
||||||
|
return middleware
|
||||||
|
return None
|
||||||
|
|
||||||
async def __notify_update(self, new_messages):
|
async def __notify_update(self, new_messages):
|
||||||
if len(self.update_listener) == 0:
|
if len(self.update_listener) == 0:
|
||||||
@ -3759,56 +3767,16 @@ class AsyncTeleBot:
|
|||||||
elif isinstance(filter_check, asyncio_filters.AdvancedCustomFilter):
|
elif isinstance(filter_check, asyncio_filters.AdvancedCustomFilter):
|
||||||
return await filter_check.check(message, filter_value)
|
return await filter_check.check(message, filter_value)
|
||||||
else:
|
else:
|
||||||
logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter!")
|
logger.error("Custom filter: wrong type. Should be SimpleCustomFilter or AdvancedCustomFilter.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def middleware_handler(self, update_types=None):
|
def setup_middleware(self, middleware):
|
||||||
"""
|
"""
|
||||||
Middleware handler decorator.
|
Setup middleware
|
||||||
|
:param middleware:
|
||||||
This decorator can be used to decorate functions that must be handled as middlewares before entering any other
|
|
||||||
message handlers
|
|
||||||
But, be careful and check type of the update inside the handler if more than one update_type is given
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
bot = TeleBot('TOKEN')
|
|
||||||
|
|
||||||
# Print post message text before entering to any post_channel handlers
|
|
||||||
@bot.middleware_handler(update_types=['channel_post', 'edited_channel_post'])
|
|
||||||
def print_channel_post_text(bot_instance, channel_post):
|
|
||||||
print(channel_post.text)
|
|
||||||
|
|
||||||
# Print update id before entering to any handlers
|
|
||||||
@bot.middleware_handler()
|
|
||||||
def print_channel_post_text(bot_instance, update):
|
|
||||||
print(update.update_id)
|
|
||||||
|
|
||||||
:param update_types: Optional list of update types that can be passed into the middleware handler.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(handler):
|
|
||||||
self.add_middleware_handler(handler, update_types)
|
|
||||||
return handler
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def add_middleware_handler(self, handler, update_types=None):
|
|
||||||
"""
|
|
||||||
Add middleware handler
|
|
||||||
:param handler:
|
|
||||||
:param update_types:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if not asyncio_helper.ENABLE_MIDDLEWARE:
|
self.middlewares.append(middleware)
|
||||||
raise RuntimeError("Middleware is not enabled. Use asyncio_helper.ENABLE_MIDDLEWARE.")
|
|
||||||
|
|
||||||
if update_types:
|
|
||||||
for update_type in update_types:
|
|
||||||
self.typed_middleware_handlers[update_type].append(handler)
|
|
||||||
else:
|
|
||||||
self.default_middleware_handlers.append(handler)
|
|
||||||
|
|
||||||
def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs):
|
def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -1,146 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import threading
|
|
||||||
|
|
||||||
from telebot import apihelper
|
|
||||||
|
|
||||||
|
|
||||||
class HandlerBackend(object):
|
|
||||||
"""
|
|
||||||
Class for saving (next step|reply) handlers
|
|
||||||
"""
|
|
||||||
def __init__(self, handlers=None):
|
|
||||||
if handlers is None:
|
|
||||||
handlers = {}
|
|
||||||
self.handlers = handlers
|
|
||||||
|
|
||||||
async def register_handler(self, handler_group_id, handler):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def clear_handlers(self, handler_group_id):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def get_handlers(self, handler_group_id):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryHandlerBackend(HandlerBackend):
|
|
||||||
async def register_handler(self, handler_group_id, handler):
|
|
||||||
if handler_group_id in self.handlers:
|
|
||||||
self.handlers[handler_group_id].append(handler)
|
|
||||||
else:
|
|
||||||
self.handlers[handler_group_id] = [handler]
|
|
||||||
|
|
||||||
async def clear_handlers(self, handler_group_id):
|
|
||||||
self.handlers.pop(handler_group_id, None)
|
|
||||||
|
|
||||||
async def get_handlers(self, handler_group_id):
|
|
||||||
return self.handlers.pop(handler_group_id, None)
|
|
||||||
|
|
||||||
async def load_handlers(self, filename, del_file_after_loading):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class FileHandlerBackend(HandlerBackend):
|
|
||||||
def __init__(self, handlers=None, filename='./.handler-saves/handlers.save', delay=120):
|
|
||||||
super(FileHandlerBackend, self).__init__(handlers)
|
|
||||||
self.filename = filename
|
|
||||||
self.delay = delay
|
|
||||||
self.timer = threading.Timer(delay, self.save_handlers)
|
|
||||||
|
|
||||||
async def register_handler(self, handler_group_id, handler):
|
|
||||||
if handler_group_id in self.handlers:
|
|
||||||
self.handlers[handler_group_id].append(handler)
|
|
||||||
else:
|
|
||||||
self.handlers[handler_group_id] = [handler]
|
|
||||||
await self.start_save_timer()
|
|
||||||
|
|
||||||
async def clear_handlers(self, handler_group_id):
|
|
||||||
self.handlers.pop(handler_group_id, None)
|
|
||||||
await self.start_save_timer()
|
|
||||||
|
|
||||||
async def get_handlers(self, handler_group_id):
|
|
||||||
handlers = self.handlers.pop(handler_group_id, None)
|
|
||||||
await self.start_save_timer()
|
|
||||||
return handlers
|
|
||||||
|
|
||||||
async def start_save_timer(self):
|
|
||||||
if not self.timer.is_alive():
|
|
||||||
if self.delay <= 0:
|
|
||||||
self.save_handlers()
|
|
||||||
else:
|
|
||||||
self.timer = threading.Timer(self.delay, self.save_handlers)
|
|
||||||
self.timer.start()
|
|
||||||
|
|
||||||
async def save_handlers(self):
|
|
||||||
await self.dump_handlers(self.handlers, self.filename)
|
|
||||||
|
|
||||||
async def load_handlers(self, filename=None, del_file_after_loading=True):
|
|
||||||
if not filename:
|
|
||||||
filename = self.filename
|
|
||||||
tmp = await self.return_load_handlers(filename, del_file_after_loading=del_file_after_loading)
|
|
||||||
if tmp is not None:
|
|
||||||
self.handlers.update(tmp)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def dump_handlers(handlers, filename, file_mode="wb"):
|
|
||||||
dirs = filename.rsplit('/', maxsplit=1)[0]
|
|
||||||
os.makedirs(dirs, exist_ok=True)
|
|
||||||
|
|
||||||
with open(filename + ".tmp", file_mode) as file:
|
|
||||||
if (apihelper.CUSTOM_SERIALIZER is None):
|
|
||||||
pickle.dump(handlers, file)
|
|
||||||
else:
|
|
||||||
apihelper.CUSTOM_SERIALIZER.dump(handlers, file)
|
|
||||||
|
|
||||||
if os.path.isfile(filename):
|
|
||||||
os.remove(filename)
|
|
||||||
|
|
||||||
os.rename(filename + ".tmp", filename)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def return_load_handlers(filename, del_file_after_loading=True):
|
|
||||||
if os.path.isfile(filename) and os.path.getsize(filename) > 0:
|
|
||||||
with open(filename, "rb") as file:
|
|
||||||
if (apihelper.CUSTOM_SERIALIZER is None):
|
|
||||||
handlers = pickle.load(file)
|
|
||||||
else:
|
|
||||||
handlers = apihelper.CUSTOM_SERIALIZER.load(file)
|
|
||||||
|
|
||||||
if del_file_after_loading:
|
|
||||||
os.remove(filename)
|
|
||||||
|
|
||||||
return handlers
|
|
||||||
|
|
||||||
|
|
||||||
class RedisHandlerBackend(HandlerBackend):
|
|
||||||
def __init__(self, handlers=None, host='localhost', port=6379, db=0, prefix='telebot', password=None):
|
|
||||||
super(RedisHandlerBackend, self).__init__(handlers)
|
|
||||||
from redis import Redis
|
|
||||||
self.prefix = prefix
|
|
||||||
self.redis = Redis(host, port, db, password)
|
|
||||||
|
|
||||||
async def _key(self, handle_group_id):
|
|
||||||
return ':'.join((self.prefix, str(handle_group_id)))
|
|
||||||
|
|
||||||
async def register_handler(self, handler_group_id, handler):
|
|
||||||
handlers = []
|
|
||||||
value = self.redis.get(self._key(handler_group_id))
|
|
||||||
if value:
|
|
||||||
handlers = pickle.loads(value)
|
|
||||||
handlers.append(handler)
|
|
||||||
self.redis.set(self._key(handler_group_id), pickle.dumps(handlers))
|
|
||||||
|
|
||||||
async def clear_handlers(self, handler_group_id):
|
|
||||||
self.redis.delete(self._key(handler_group_id))
|
|
||||||
|
|
||||||
async def get_handlers(self, handler_group_id):
|
|
||||||
handlers = None
|
|
||||||
value = self.redis.get(self._key(handler_group_id))
|
|
||||||
if value:
|
|
||||||
handlers = pickle.loads(value)
|
|
||||||
self.clear_handlers(handler_group_id)
|
|
||||||
return handlers
|
|
||||||
|
|
||||||
|
|
||||||
class StateMemory:
|
class StateMemory:
|
||||||
@ -341,3 +201,19 @@ class StateFileContext:
|
|||||||
await self.obj._save_data(old_data)
|
await self.obj._save_data(old_data)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMiddleware:
|
||||||
|
"""
|
||||||
|
Base class for middleware.
|
||||||
|
|
||||||
|
Your middlewares should be inherited from this class.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def pre_process(self, message, data):
|
||||||
|
raise NotImplementedError
|
||||||
|
async def post_process(self, message, data, exception):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -1,28 +1,19 @@
|
|||||||
import asyncio
|
import asyncio # for future uses
|
||||||
from time import time
|
from time import time
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from telebot import types
|
from telebot import types
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ujson as json
|
import ujson as json
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import requests
|
|
||||||
from requests.exceptions import HTTPError, ConnectionError, Timeout
|
|
||||||
|
|
||||||
try:
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
from requests.packages.urllib3 import fields
|
|
||||||
format_header_param = fields.format_header_param
|
|
||||||
except ImportError:
|
|
||||||
format_header_param = None
|
|
||||||
|
|
||||||
API_URL = 'https://api.telegram.org/bot{0}/{1}'
|
API_URL = 'https://api.telegram.org/bot{0}/{1}'
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import telebot
|
||||||
from telebot import util
|
from telebot import util
|
||||||
|
|
||||||
class SessionBase:
|
class SessionBase:
|
||||||
@ -44,19 +35,16 @@ READ_TIMEOUT = 30
|
|||||||
|
|
||||||
LONG_POLLING_TIMEOUT = 10 # Should be positive, short polling should be used for testing purposes only (https://core.telegram.org/bots/api#getupdates)
|
LONG_POLLING_TIMEOUT = 10 # Should be positive, short polling should be used for testing purposes only (https://core.telegram.org/bots/api#getupdates)
|
||||||
|
|
||||||
|
logger = telebot.logger
|
||||||
|
|
||||||
RETRY_ON_ERROR = False
|
RETRY_ON_ERROR = False
|
||||||
RETRY_TIMEOUT = 2
|
RETRY_TIMEOUT = 2
|
||||||
MAX_RETRIES = 15
|
MAX_RETRIES = 15
|
||||||
|
|
||||||
CUSTOM_SERIALIZER = None
|
async def _process_request(token, url, method='get', params=None, files=None, request_timeout=None):
|
||||||
CUSTOM_REQUEST_SENDER = None
|
|
||||||
|
|
||||||
ENABLE_MIDDLEWARE = False
|
|
||||||
|
|
||||||
async def _process_request(token, url, method='get', params=None, files=None):
|
|
||||||
async with await session_manager._get_new_session() as session:
|
async with await session_manager._get_new_session() as session:
|
||||||
async with session.get(API_URL.format(token, url), params=params, data=files) as response:
|
async with session.get(API_URL.format(token, url), params=params, data=files, timeout=request_timeout) as response:
|
||||||
|
logger.debug("Request: method={0} url={1} params={2} files={3} request_timeout={4}".format(method, url, params, files, request_timeout).replace(token, token.split(':')[0] + ":{TOKEN}"))
|
||||||
json_result = await _check_result(url, response)
|
json_result = await _check_result(url, response)
|
||||||
if json_result:
|
if json_result:
|
||||||
return json_result['result']
|
return json_result['result']
|
||||||
@ -155,7 +143,7 @@ async def get_webhook_info(token, timeout=None):
|
|||||||
|
|
||||||
|
|
||||||
async def get_updates(token, offset=None, limit=None,
|
async def get_updates(token, offset=None, limit=None,
|
||||||
timeout=None, allowed_updates=None, long_polling_timeout=None):
|
timeout=None, allowed_updates=None, request_timeout=None):
|
||||||
method_name = 'getUpdates'
|
method_name = 'getUpdates'
|
||||||
params = {}
|
params = {}
|
||||||
if offset:
|
if offset:
|
||||||
@ -166,8 +154,7 @@ async def get_updates(token, offset=None, limit=None,
|
|||||||
params['timeout'] = timeout
|
params['timeout'] = timeout
|
||||||
elif allowed_updates:
|
elif allowed_updates:
|
||||||
params['allowed_updates'] = allowed_updates
|
params['allowed_updates'] = allowed_updates
|
||||||
params['long_polling_timeout'] = long_polling_timeout if long_polling_timeout else LONG_POLLING_TIMEOUT
|
return await _process_request(token, method_name, params=params, request_timeout=request_timeout)
|
||||||
return await _process_request(token, method_name, params=params)
|
|
||||||
|
|
||||||
async def _check_result(method_name, result):
|
async def _check_result(method_name, result):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user