From a3cda2e0ff9b0bd9ac0d210009a5b7c74fb36935 Mon Sep 17 00:00:00 2001 From: _run Date: Mon, 24 Jan 2022 17:15:04 +0400 Subject: [PATCH] Updated sync and async. Fixes and new features. --- telebot/__init__.py | 151 +++++++++++----- telebot/async_telebot.py | 168 ++++++++++++------ telebot/asyncio_filters.py | 20 ++- telebot/asyncio_handler_backends.py | 200 --------------------- telebot/asyncio_helper.py | 60 ++++--- telebot/asyncio_storage/__init__.py | 13 ++ telebot/asyncio_storage/base_storage.py | 69 ++++++++ telebot/asyncio_storage/memory_storage.py | 64 +++++++ telebot/asyncio_storage/pickle_storage.py | 107 ++++++++++++ telebot/asyncio_storage/redis_storage.py | 178 +++++++++++++++++++ telebot/asyncio_types.py | 1 + telebot/custom_filters.py | 20 ++- telebot/handler_backends.py | 204 +--------------------- telebot/storage/__init__.py | 13 ++ telebot/storage/base_storage.py | 65 +++++++ telebot/storage/memory_storage.py | 64 +++++++ telebot/storage/pickle_storage.py | 112 ++++++++++++ telebot/storage/redis_storage.py | 176 +++++++++++++++++++ 18 files changed, 1160 insertions(+), 525 deletions(-) create mode 100644 telebot/asyncio_storage/__init__.py create mode 100644 telebot/asyncio_storage/base_storage.py create mode 100644 telebot/asyncio_storage/memory_storage.py create mode 100644 telebot/asyncio_storage/pickle_storage.py create mode 100644 telebot/asyncio_storage/redis_storage.py create mode 100644 telebot/asyncio_types.py create mode 100644 telebot/storage/__init__.py create mode 100644 telebot/storage/base_storage.py create mode 100644 telebot/storage/memory_storage.py create mode 100644 telebot/storage/pickle_storage.py create mode 100644 telebot/storage/redis_storage.py diff --git a/telebot/__init__.py b/telebot/__init__.py index 5fed5c1..fca501c 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -13,7 +13,8 @@ from typing import Any, Callable, List, Optional, Union import telebot.util import telebot.types - +# storage +from telebot.storage import StatePickleStorage, StateMemoryStorage logger = logging.getLogger('TeleBot') @@ -28,7 +29,7 @@ logger.addHandler(console_output_handler) logger.setLevel(logging.ERROR) from telebot import apihelper, util, types -from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend, StateMemory, StateFile +from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter @@ -148,7 +149,7 @@ class TeleBot: def __init__( self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2, next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0, - suppress_middleware_excepions=False + suppress_middleware_excepions=False, state_storage=StateMemoryStorage() ): """ :param token: bot API token @@ -193,7 +194,7 @@ class TeleBot: self.custom_filters = {} self.state_handlers = [] - self.current_states = StateMemory() + self.current_states = state_storage if apihelper.ENABLE_MIDDLEWARE: @@ -251,7 +252,7 @@ class TeleBot: :param filename: Filename of saving file """ - self.current_states = StateFile(filename=filename) + self.current_states = StatePickleStorage(filename=filename) self.current_states.create_dir() def enable_save_reply_handlers(self, delay=120, filename="./.handler-saves/reply.save"): @@ -777,6 +778,13 @@ class TeleBot: logger.info('Stopped polling.') def _exec_task(self, task, *args, **kwargs): + if kwargs.get('task_type') == 'handler': + pass_bot = kwargs.get('pass_bot') + kwargs.pop('pass_bot') + kwargs.pop('task_type') + if pass_bot: + kwargs['bot'] = self + if self.threaded: self.worker_pool.put(task, *args, **kwargs) else: @@ -2531,40 +2539,59 @@ class TeleBot: chat_id = message.chat.id self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs) - def set_state(self, chat_id: int, state: Union[int, str]): + def set_state(self, user_id: int, state: Union[int, str], chat_id: int=None) -> None: """ Sets a new state of a user. :param chat_id: :param state: new state. can be string or integer. """ - self.current_states.add_state(chat_id, state) + if chat_id is None: + chat_id = user_id + self.current_states.set_state(chat_id, user_id, state) - def delete_state(self, chat_id: int): + def reset_data(self, user_id: int, chat_id: int=None): + """ + Reset data for a user in chat. + :param user_id: + :param chat_id: + """ + if chat_id is None: + chat_id = user_id + self.current_states.reset_data(chat_id, user_id) + def delete_state(self, user_id: int, chat_id: int=None) -> None: """ Delete the current state of a user. :param chat_id: :return: """ - self.current_states.delete_state(chat_id) + if chat_id is None: + chat_id = user_id + self.current_states.delete_state(chat_id, user_id) - def retrieve_data(self, chat_id: int): - return self.current_states.retrieve_data(chat_id) + def retrieve_data(self, user_id: int, chat_id: int=None) -> Optional[Union[int, str]]: + if chat_id is None: + chat_id = user_id + return self.current_states.get_interactive_data(chat_id, user_id) - def get_state(self, chat_id: int): + def get_state(self, user_id: int, chat_id: int=None) -> Optional[Union[int, str]]: """ Get current state of a user. :param chat_id: :return: state of a user """ - return self.current_states.current_state(chat_id) + if chat_id is None: + chat_id = user_id + return self.current_states.get_state(chat_id, user_id) - def add_data(self, chat_id: int, **kwargs): + def add_data(self, user_id: int, chat_id:int=None, **kwargs): """ Add data to states. :param chat_id: """ + if chat_id is None: + chat_id = user_id for key, value in kwargs.items(): - self.current_states.add_data(chat_id, key, value) + self.current_states.set_data(chat_id, user_id, key, value) def register_next_step_handler_by_chat_id( self, chat_id: Union[int, str], callback: Callable, *args, **kwargs) -> None: @@ -2632,7 +2659,7 @@ class TeleBot: @staticmethod - def _build_handler_dict(handler, **filters): + def _build_handler_dict(handler, pass_bot=False, **filters): """ Builds a dictionary for a handler :param handler: @@ -2641,6 +2668,7 @@ class TeleBot: """ return { 'function': handler, + 'pass_bot': pass_bot, 'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None} # Remove None values, they are skipped in _test_filter anyway #'filters': filters @@ -2686,7 +2714,7 @@ class TeleBot: :return: """ if not apihelper.ENABLE_MIDDLEWARE: - raise RuntimeError("Middleware is not enabled. Use apihelper.ENABLE_MIDDLEWARE.") + raise RuntimeError("Middleware is not enabled. Use apihelper.ENABLE_MIDDLEWARE before initialising TeleBot.") if update_types: for update_type in update_types: @@ -2694,6 +2722,27 @@ class TeleBot: else: self.default_middleware_handlers.append(handler) + # function register_middleware_handler + def register_middleware_handler(self, callback, update_types=None): + """ + Middleware handler decorator. + + This function will create a decorator that can be used to decorate functions that must be handled as middlewares before entering any other + message handlers + But, be careful and check type of the update inside the handler if more than one update_type is given + + Example: + + bot = TeleBot('TOKEN') + + bot.register_middleware_handler(print_channel_post_text, update_types=['channel_post', 'edited_channel_post']) + + :param update_types: Optional list of update types that can be passed into the middleware handler. + + """ + + self.add_middleware_handler(callback, update_types) + def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs): """ Message handler decorator. @@ -2766,7 +2815,7 @@ class TeleBot: """ self.message_handlers.append(handler_dict) - def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): + def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs): """ Registers message handler. :param callback: function to be called @@ -2775,6 +2824,7 @@ class TeleBot: :param regexp: :param func: :param chat_types: True for private chat + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ if isinstance(commands, str): @@ -2791,6 +2841,7 @@ class TeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_message_handler(handler_dict) @@ -2838,7 +2889,7 @@ class TeleBot: """ self.edited_message_handlers.append(handler_dict) - def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): + def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs): """ Registers edited message handler. :param callback: function to be called @@ -2847,6 +2898,7 @@ class TeleBot: :param regexp: :param func: :param chat_types: True for private chat + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ if isinstance(commands, str): @@ -2863,6 +2915,7 @@ class TeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_edited_message_handler(handler_dict) @@ -2908,7 +2961,7 @@ class TeleBot: """ self.channel_post_handlers.append(handler_dict) - def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): + def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs): """ Registers channel post message handler. :param callback: function to be called @@ -2916,6 +2969,7 @@ class TeleBot: :param commands: list of commands :param regexp: :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ if isinstance(commands, str): @@ -2931,6 +2985,7 @@ class TeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_channel_post_handler(handler_dict) @@ -2975,7 +3030,7 @@ class TeleBot: """ self.edited_channel_post_handlers.append(handler_dict) - def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): + def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs): """ Registers edited channel post message handler. :param callback: function to be called @@ -2983,6 +3038,7 @@ class TeleBot: :param commands: list of commands :param regexp: :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ if isinstance(commands, str): @@ -2998,6 +3054,7 @@ class TeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_edited_channel_post_handler(handler_dict) @@ -3024,14 +3081,15 @@ class TeleBot: """ self.inline_handlers.append(handler_dict) - def register_inline_handler(self, callback, func, **kwargs): + def register_inline_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers inline handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_inline_handler(handler_dict) def chosen_inline_handler(self, func, **kwargs): @@ -3057,14 +3115,15 @@ class TeleBot: """ self.chosen_inline_handlers.append(handler_dict) - def register_chosen_inline_handler(self, callback, func, **kwargs): + def register_chosen_inline_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers chosen inline handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_chosen_inline_handler(handler_dict) def callback_query_handler(self, func, **kwargs): @@ -3090,14 +3149,15 @@ class TeleBot: """ self.callback_query_handlers.append(handler_dict) - def register_callback_query_handler(self, callback, func, **kwargs): + def register_callback_query_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers callback query handler.. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_callback_query_handler(handler_dict) def shipping_query_handler(self, func, **kwargs): @@ -3123,14 +3183,15 @@ class TeleBot: """ self.shipping_query_handlers.append(handler_dict) - def register_shipping_query_handler(self, callback, func, **kwargs): + def register_shipping_query_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers shipping query handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_shipping_query_handler(handler_dict) def pre_checkout_query_handler(self, func, **kwargs): @@ -3156,14 +3217,15 @@ class TeleBot: """ self.pre_checkout_query_handlers.append(handler_dict) - def register_pre_checkout_query_handler(self, callback, func, **kwargs): + def register_pre_checkout_query_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers pre-checkout request handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_pre_checkout_query_handler(handler_dict) def poll_handler(self, func, **kwargs): @@ -3189,14 +3251,15 @@ class TeleBot: """ self.poll_handlers.append(handler_dict) - def register_poll_handler(self, callback, func, **kwargs): + def register_poll_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers poll handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_poll_handler(handler_dict) def poll_answer_handler(self, func=None, **kwargs): @@ -3222,14 +3285,15 @@ class TeleBot: """ self.poll_answer_handlers.append(handler_dict) - def register_poll_answer_handler(self, callback, func, **kwargs): + def register_poll_answer_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers poll answer handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_poll_answer_handler(handler_dict) def my_chat_member_handler(self, func=None, **kwargs): @@ -3255,14 +3319,15 @@ class TeleBot: """ self.my_chat_member_handlers.append(handler_dict) - def register_my_chat_member_handler(self, callback, func=None, **kwargs): + def register_my_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs): """ Registers my chat member handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_my_chat_member_handler(handler_dict) def chat_member_handler(self, func=None, **kwargs): @@ -3288,14 +3353,15 @@ class TeleBot: """ self.chat_member_handlers.append(handler_dict) - def register_chat_member_handler(self, callback, func=None, **kwargs): + def register_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs): """ Registers chat member handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_chat_member_handler(handler_dict) def chat_join_request_handler(self, func=None, **kwargs): @@ -3321,14 +3387,15 @@ class TeleBot: """ self.chat_join_request_handlers.append(handler_dict) - def register_chat_join_request_handler(self, callback, func=None, **kwargs): + def register_chat_join_request_handler(self, callback, func=None, pass_bot=False, **kwargs): """ Registers chat join request handler. :param callback: function to be called :param func: + :param pass_bot: Pass TeleBot to handler. :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_chat_join_request_handler(handler_dict) def _test_message_handler(self, message_handler, message): @@ -3409,7 +3476,7 @@ class TeleBot: for message in new_messages: for message_handler in handlers: if self._test_message_handler(message_handler, message): - self._exec_task(message_handler['function'], message) + self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler') break diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index 6bfa799..f741f63 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -13,6 +13,9 @@ import telebot.util import telebot.types +# storages +from telebot.asyncio_storage import StateMemoryStorage, StatePickleStorage + from inspect import signature from telebot import logger @@ -161,7 +164,7 @@ class AsyncTeleBot: """ def __init__(self, token: str, parse_mode: Optional[str]=None, offset=None, - exception_handler=None) -> None: # TODO: ADD TYPEHINTS + exception_handler=None, states_storage=StateMemoryStorage()) -> None: # TODO: ADD TYPEHINTS self.token = token self.offset = offset @@ -190,12 +193,13 @@ class AsyncTeleBot: self.custom_filters = {} self.state_handlers = [] - self.current_states = asyncio_handler_backends.StateMemory() + self.current_states = states_storage self.middlewares = [] - + async def close_session(self): + await asyncio_helper.session_manager.session.close() async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None, timeout: Optional[int]=None, allowed_updates: Optional[List]=None, request_timeout: Optional[int]=None) -> List[types.Update]: json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates, request_timeout) @@ -299,7 +303,7 @@ class AsyncTeleBot: updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout) if updates: self.offset = updates[-1].update_id + 1 - self._loop_create_task(self.process_new_updates(updates)) # Seperate task for processing updates + asyncio.create_task(self.process_new_updates(updates)) # Seperate task for processing updates if interval: await asyncio.sleep(interval) except KeyboardInterrupt: @@ -322,6 +326,8 @@ class AsyncTeleBot: continue else: break + except KeyboardInterrupt: + return except Exception as e: logger.error('Cause exception while getting updates.') if non_stop: @@ -333,6 +339,7 @@ class AsyncTeleBot: finally: self._polling = False + await self.close_session() logger.warning('Polling is stopped.') @@ -346,31 +353,47 @@ class AsyncTeleBot: :param messages: :return: """ + tasks = [] for message in messages: middleware = await self.process_middlewares(message, update_type) - self._loop_create_task(self._run_middlewares_and_handlers(handlers, message, middleware)) + tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware)) + asyncio.gather(*tasks) async def _run_middlewares_and_handlers(self, handlers, message, middleware): handler_error = None data = {} - for message_handler in handlers: - process_update = await self._test_message_handler(message_handler, message) + process_handler = True + middleware_result = await middleware.pre_process(message, data) + if isinstance(middleware_result, SkipHandler): + await middleware.post_process(message, data, handler_error) + process_handler = False + if isinstance(middleware_result, CancelUpdate): + return + for handler in handlers: + if not process_handler: + break + + process_update = await self._test_message_handler(handler, message) if not process_update: continue elif process_update: - if middleware: - middleware_result = await middleware.pre_process(message, data) - if isinstance(middleware_result, SkipHandler): - await middleware.post_process(message, data, handler_error) - break - if isinstance(middleware_result, CancelUpdate): - return try: - if "data" in signature(message_handler['function']).parameters: - await message_handler['function'](message, data) - else: - await message_handler['function'](message) + params = [] + + for i in signature(handler['function']).parameters: + params.append(i) + if len(params) == 1: + await handler['function'](message) + break + if params[1] == 'data' and handler.get('pass_bot') is True: + await handler['function'](message, data, self) + break + elif params[1] == 'data' and handler.get('pass_bot') is False: + await handler['function'](message, data) + break + elif params[1] != 'data' and handler.get('pass_bot') is True: + await handler['function'](message, self) break except Exception as e: handler_error = e @@ -687,7 +710,7 @@ class AsyncTeleBot: """ self.message_handlers.append(handler_dict) - def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): + def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs): """ Registers message handler. :param callback: function to be called @@ -696,8 +719,11 @@ class AsyncTeleBot: :param regexp: :param func: :param chat_types: True for private chat + :param pass_bot: True if you want to get TeleBot instance in your handler :return: decorated function """ + if content_types is None: + content_types = ["text"] if isinstance(commands, str): logger.warning("register_message_handler: 'commands' filter should be List of strings (commands), not string.") commands = [commands] @@ -712,6 +738,7 @@ class AsyncTeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_message_handler(handler_dict) @@ -759,7 +786,7 @@ class AsyncTeleBot: """ self.edited_message_handlers.append(handler_dict) - def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): + def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs): """ Registers edited message handler. :param callback: function to be called @@ -784,6 +811,7 @@ class AsyncTeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_edited_message_handler(handler_dict) @@ -829,7 +857,7 @@ class AsyncTeleBot: """ self.channel_post_handlers.append(handler_dict) - def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): + def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs): """ Registers channel post message handler. :param callback: function to be called @@ -852,6 +880,7 @@ class AsyncTeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_channel_post_handler(handler_dict) @@ -896,7 +925,7 @@ class AsyncTeleBot: """ self.edited_channel_post_handlers.append(handler_dict) - def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): + def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs): """ Registers edited channel post message handler. :param callback: function to be called @@ -919,6 +948,7 @@ class AsyncTeleBot: commands=commands, regexp=regexp, func=func, + pass_bot=pass_bot, **kwargs) self.add_edited_channel_post_handler(handler_dict) @@ -945,14 +975,14 @@ class AsyncTeleBot: """ self.inline_handlers.append(handler_dict) - def register_inline_handler(self, callback, func, **kwargs): + def register_inline_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers inline handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_inline_handler(handler_dict) def chosen_inline_handler(self, func, **kwargs): @@ -978,14 +1008,14 @@ class AsyncTeleBot: """ self.chosen_inline_handlers.append(handler_dict) - def register_chosen_inline_handler(self, callback, func, **kwargs): + def register_chosen_inline_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers chosen inline handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_chosen_inline_handler(handler_dict) def callback_query_handler(self, func, **kwargs): @@ -1011,14 +1041,14 @@ class AsyncTeleBot: """ self.callback_query_handlers.append(handler_dict) - def register_callback_query_handler(self, callback, func, **kwargs): + def register_callback_query_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers callback query handler.. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_callback_query_handler(handler_dict) def shipping_query_handler(self, func, **kwargs): @@ -1044,14 +1074,14 @@ class AsyncTeleBot: """ self.shipping_query_handlers.append(handler_dict) - def register_shipping_query_handler(self, callback, func, **kwargs): + def register_shipping_query_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers shipping query handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_shipping_query_handler(handler_dict) def pre_checkout_query_handler(self, func, **kwargs): @@ -1077,14 +1107,14 @@ class AsyncTeleBot: """ self.pre_checkout_query_handlers.append(handler_dict) - def register_pre_checkout_query_handler(self, callback, func, **kwargs): + def register_pre_checkout_query_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers pre-checkout request handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_pre_checkout_query_handler(handler_dict) def poll_handler(self, func, **kwargs): @@ -1110,14 +1140,14 @@ class AsyncTeleBot: """ self.poll_handlers.append(handler_dict) - def register_poll_handler(self, callback, func, **kwargs): + def register_poll_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers poll handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_poll_handler(handler_dict) def poll_answer_handler(self, func=None, **kwargs): @@ -1143,14 +1173,14 @@ class AsyncTeleBot: """ self.poll_answer_handlers.append(handler_dict) - def register_poll_answer_handler(self, callback, func, **kwargs): + def register_poll_answer_handler(self, callback, func, pass_bot=False, **kwargs): """ Registers poll answer handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_poll_answer_handler(handler_dict) def my_chat_member_handler(self, func=None, **kwargs): @@ -1176,14 +1206,14 @@ class AsyncTeleBot: """ self.my_chat_member_handlers.append(handler_dict) - def register_my_chat_member_handler(self, callback, func=None, **kwargs): + def register_my_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs): """ Registers my chat member handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_my_chat_member_handler(handler_dict) def chat_member_handler(self, func=None, **kwargs): @@ -1209,14 +1239,14 @@ class AsyncTeleBot: """ self.chat_member_handlers.append(handler_dict) - def register_chat_member_handler(self, callback, func=None, **kwargs): + def register_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs): """ Registers chat member handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_chat_member_handler(handler_dict) def chat_join_request_handler(self, func=None, **kwargs): @@ -1242,18 +1272,18 @@ class AsyncTeleBot: """ self.chat_join_request_handlers.append(handler_dict) - def register_chat_join_request_handler(self, callback, func=None, **kwargs): + def register_chat_join_request_handler(self, callback, func=None, pass_bot=False, **kwargs): """ Registers chat join request handler. :param callback: function to be called :param func: :return: decorated function """ - handler_dict = self._build_handler_dict(callback, func=func, **kwargs) + handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs) self.add_chat_join_request_handler(handler_dict) @staticmethod - def _build_handler_dict(handler, **filters): + def _build_handler_dict(handler, pass_bot=False, **filters): """ Builds a dictionary for a handler :param handler: @@ -1262,6 +1292,7 @@ class AsyncTeleBot: """ return { 'function': handler, + 'pass_bot': pass_bot, 'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None} # Remove None values, they are skipped in _test_filter anyway #'filters': filters @@ -1324,8 +1355,7 @@ class AsyncTeleBot: :param filename: Filename of saving file """ - self.current_states = asyncio_handler_backends.StateFile(filename=filename) - self.current_states.create_dir() + self.current_states = StatePickleStorage(file_path=filename) async def set_webhook(self, url=None, certificate=None, max_connections=None, allowed_updates=None, ip_address=None, drop_pending_updates = None, timeout=None): @@ -1356,6 +1386,8 @@ class AsyncTeleBot: return await asyncio_helper.set_webhook(self.token, url, certificate, max_connections, allowed_updates, ip_address, drop_pending_updates, timeout) + + async def delete_webhook(self, drop_pending_updates=None, timeout=None): """ Use this method to remove webhook integration if you decide to switch back to getUpdates. @@ -1366,6 +1398,12 @@ class AsyncTeleBot: """ return await asyncio_helper.delete_webhook(self.token, drop_pending_updates, timeout) + async def remove_webhook(self): + """ + Alternative for delete_webhook but uses set_webhook + """ + self.set_webhook() + async def get_webhook_info(self, timeout=None): """ Use this method to get current webhook status. Requires no parameters. @@ -3019,37 +3057,57 @@ class AsyncTeleBot: return await asyncio_helper.delete_sticker_from_set(self.token, sticker) - async def set_state(self, chat_id, state): + async def set_state(self, user_id: int, state: str, chat_id: int=None): """ Sets a new state of a user. :param chat_id: :param state: new state. can be string or integer. """ - await self.current_states.add_state(chat_id, state) + if not chat_id: + chat_id = user_id + await self.current_states.set_state(chat_id, user_id, state) - async def delete_state(self, chat_id): + async def reset_data(self, user_id: int, chat_id: int=None): + """ + Reset data for a user in chat. + :param user_id: + :param chat_id: + """ + if chat_id is None: + chat_id = user_id + await self.current_states.reset_data(chat_id, user_id) + + async def delete_state(self, user_id: int, chat_id:int=None): """ Delete the current state of a user. :param chat_id: :return: """ - await self.current_states.delete_state(chat_id) + if not chat_id: + chat_id = user_id + await self.current_states.delete_state(chat_id, user_id) - def retrieve_data(self, chat_id): - return self.current_states.retrieve_data(chat_id) + def retrieve_data(self, user_id: int, chat_id: int=None): + if not chat_id: + chat_id = user_id + return self.current_states.get_interactive_data(chat_id, user_id) - async def get_state(self, chat_id): + async def get_state(self, user_id, chat_id: int=None): """ Get current state of a user. :param chat_id: :return: state of a user """ - return await self.current_states.current_state(chat_id) + if not chat_id: + chat_id = user_id + return await self.current_states.get_state(chat_id, user_id) - async def add_data(self, chat_id, **kwargs): + async def add_data(self, user_id: int, chat_id: int=None, **kwargs): """ Add data to states. :param chat_id: """ + if not chat_id: + chat_id = user_id for key, value in kwargs.items(): - await self.current_states.add_data(chat_id, key, value) + await self.current_states.set_data(chat_id, user_id, key, value) diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index cce7017..5b193fd 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -159,11 +159,21 @@ class StateFilter(AdvancedCustomFilter): key = 'state' async def check(self, message, text): - result = await self.bot.current_states.current_state(message.from_user.id) - if result is False: return False - elif text == '*': return True - elif type(text) is list: return result in text - return result == text + if text == '*': return True + if message.chat.type == 'group': + group_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id) + if group_state == text: + return True + elif group_state in text and type(text) is list: + return True + + + else: + user_state = await self.bot.current_states.get_state(message.chat.id,message.from_user.id) + if user_state == text: + return True + elif type(text) is list and user_state in text: + return True class IsDigitFilter(SimpleCustomFilter): """ diff --git a/telebot/asyncio_handler_backends.py b/telebot/asyncio_handler_backends.py index 0a78a90..08db40f 100644 --- a/telebot/asyncio_handler_backends.py +++ b/telebot/asyncio_handler_backends.py @@ -3,206 +3,6 @@ import pickle -class StateMemory: - def __init__(self): - self._states = {} - - async def add_state(self, chat_id, state): - """ - Add a state. - :param chat_id: - :param state: new state - """ - if chat_id in self._states: - - self._states[chat_id]['state'] = state - else: - self._states[chat_id] = {'state': state,'data': {}} - - async def current_state(self, chat_id): - """Current state""" - if chat_id in self._states: return self._states[chat_id]['state'] - else: return False - - async def delete_state(self, chat_id): - """Delete a state""" - self._states.pop(chat_id) - - def get_data(self, chat_id): - return self._states[chat_id]['data'] - - async def set(self, chat_id, new_state): - """ - Set a new state for a user. - :param chat_id: - :param new_state: new_state of a user - """ - await self.add_state(chat_id,new_state) - - async def add_data(self, chat_id, key, value): - result = self._states[chat_id]['data'][key] = value - return result - - async def finish(self, chat_id): - """ - Finish(delete) state of a user. - :param chat_id: - """ - await self.delete_state(chat_id) - - def retrieve_data(self, chat_id): - """ - Save input text. - - Usage: - with bot.retrieve_data(message.chat.id) as data: - data['name'] = message.text - - Also, at the end of your 'Form' you can get the name: - data['name'] - """ - return StateContext(self, chat_id) - - -class StateFile: - """ - Class to save states in a file. - """ - def __init__(self, filename): - self.file_path = filename - - async def add_state(self, chat_id, state): - """ - Add a state. - :param chat_id: - :param state: new state - """ - states_data = self.read_data() - if chat_id in states_data: - states_data[chat_id]['state'] = state - return await self.save_data(states_data) - else: - states_data[chat_id] = {'state': state,'data': {}} - return await self.save_data(states_data) - - - async def current_state(self, chat_id): - """Current state.""" - states_data = self.read_data() - if chat_id in states_data: return states_data[chat_id]['state'] - else: return False - - async def delete_state(self, chat_id): - """Delete a state""" - states_data = self.read_data() - states_data.pop(chat_id) - await self.save_data(states_data) - - def read_data(self): - """ - Read the data from file. - """ - file = open(self.file_path, 'rb') - states_data = pickle.load(file) - file.close() - return states_data - - def create_dir(self): - """ - Create directory .save-handlers. - """ - dirs = self.file_path.rsplit('/', maxsplit=1)[0] - os.makedirs(dirs, exist_ok=True) - if not os.path.isfile(self.file_path): - with open(self.file_path,'wb') as file: - pickle.dump({}, file) - - async def save_data(self, new_data): - """ - Save data after editing. - :param new_data: - """ - with open(self.file_path, 'wb+') as state_file: - pickle.dump(new_data, state_file, protocol=pickle.HIGHEST_PROTOCOL) - return True - - def get_data(self, chat_id): - return self.read_data()[chat_id]['data'] - - async def set(self, chat_id, new_state): - """ - Set a new state for a user. - :param chat_id: - :param new_state: new_state of a user - - """ - await self.add_state(chat_id,new_state) - - async def add_data(self, chat_id, key, value): - states_data = self.read_data() - result = states_data[chat_id]['data'][key] = value - await self.save_data(result) - - return result - - async def finish(self, chat_id): - """ - Finish(delete) state of a user. - :param chat_id: - """ - await self.delete_state(chat_id) - - def retrieve_data(self, chat_id): - """ - Save input text. - - Usage: - with bot.retrieve_data(message.chat.id) as data: - data['name'] = message.text - - Also, at the end of your 'Form' you can get the name: - data['name'] - """ - return StateFileContext(self, chat_id) - - -class StateContext: - """ - Class for data. - """ - def __init__(self , obj: StateMemory, chat_id) -> None: - self.obj = obj - self.chat_id = chat_id - self.data = obj.get_data(chat_id) - - async def __aenter__(self): - return self.data - - async def __aexit__(self, exc_type, exc_val, exc_tb): - return - -class StateFileContext: - """ - Class for data. - """ - def __init__(self , obj: StateFile, chat_id) -> None: - self.obj = obj - self.chat_id = chat_id - self.data = None - - async def __aenter__(self): - self.data = self.obj.get_data(self.chat_id) - return self.data - - async def __aexit__(self, exc_type, exc_val, exc_tb): - old_data = self.obj.read_data() - for i in self.data: - old_data[self.chat_id]['data'][i] = self.data.get(i) - await self.obj.save_data(old_data) - - return - - class BaseMiddleware: """ Base class for middleware. diff --git a/telebot/asyncio_helper.py b/telebot/asyncio_helper.py index e36a974..fe4c13e 100644 --- a/telebot/asyncio_helper.py +++ b/telebot/asyncio_helper.py @@ -12,16 +12,8 @@ API_URL = 'https://api.telegram.org/bot{0}/{1}' from datetime import datetime import telebot -from telebot import util +from telebot import util, logger -class SessionBase: - def __init__(self) -> None: - self.session = None - async def _get_new_session(self): - self.session = aiohttp.ClientSession() - return self.session - -session_manager = SessionBase() proxy = None session = None @@ -36,6 +28,29 @@ REQUEST_TIMEOUT = 10 MAX_RETRIES = 3 logger = telebot.logger + +REQUEST_LIMIT = 50 + +class SessionManager: + def __init__(self) -> None: + self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=REQUEST_LIMIT)) + + async def create_session(self): + self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=REQUEST_LIMIT)) + return self.session + + async def get_session(self): + if self.session.closed: + self.session = await self.create_session() + + if not self.session._loop.is_running(): + await self.session.close() + self.session = await self.create_session() + return self.session + + +session_manager = SessionManager() + async def _process_request(token, url, method='get', params=None, files=None, request_timeout=None): params = prepare_data(params, files) if request_timeout is None: @@ -43,19 +58,21 @@ async def _process_request(token, url, method='get', params=None, files=None, re timeout = aiohttp.ClientTimeout(total=request_timeout) got_result = False current_try=0 - async with await session_manager._get_new_session() as session: - while not got_result and current_try None: + pass + + async def set_data(self, chat_id, user_id, key, value): + """ + Set data for a user in a particular chat. + """ + raise NotImplementedError + + async def get_data(self, chat_id, user_id): + """ + Get data for a user in a particular chat. + """ + raise NotImplementedError + + async def set_state(self, chat_id, user_id, state): + """ + Set state for a particular user. + + ! Note that you should create a + record if it does not exist, and + if a record with state already exists, + you need to update a record. + """ + raise NotImplementedError + + async def delete_state(self, chat_id, user_id): + """ + Delete state for a particular user. + """ + raise NotImplementedError + + async def reset_data(self, chat_id, user_id): + """ + Reset data for a particular user in a chat. + """ + raise NotImplementedError + + async def get_state(self, chat_id, user_id): + raise NotImplementedError + + async def save(chat_id, user_id, data): + raise NotImplementedError + + + +class StateContext: + """ + Class for data. + """ + + def __init__(self, obj, chat_id, user_id): + self.obj = obj + self.data = None + self.chat_id = chat_id + self.user_id = user_id + + + + async def __aenter__(self): + self.data = copy.deepcopy(await self.obj.get_data(self.chat_id, self.user_id)) + return self.data + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.obj.save(self.chat_id, self.user_id, self.data) \ No newline at end of file diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py new file mode 100644 index 0000000..ab9c486 --- /dev/null +++ b/telebot/asyncio_storage/memory_storage.py @@ -0,0 +1,64 @@ +from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext + +class StateMemoryStorage(StateStorageBase): + def __init__(self) -> None: + self.data = {} + # + # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} + + + async def set_state(self, chat_id, user_id, state): + if chat_id in self.data: + if user_id in self.data[chat_id]: + self.data[chat_id][user_id]['state'] = state + return True + else: + self.data[chat_id][user_id] = {'state': state, 'data': {}} + return True + self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + return True + + async def delete_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + del self.data[chat_id][user_id] + if chat_id == user_id: + del self.data[chat_id] + + return True + + return False + + + async def get_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['state'] + + return None + async def get_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['data'] + + return None + + async def reset_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'] = {} + return True + return False + + async def set_data(self, chat_id, user_id, key, value): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'][key] = value + return True + raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + + def get_interactive_data(self, chat_id, user_id): + return StateContext(self, chat_id, user_id) + + async def save(self, chat_id, user_id, data): + self.data[chat_id][user_id]['data'] = data \ No newline at end of file diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py new file mode 100644 index 0000000..81ef46c --- /dev/null +++ b/telebot/asyncio_storage/pickle_storage.py @@ -0,0 +1,107 @@ +from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext +import os + + +import pickle + + +class StatePickleStorage(StateStorageBase): + def __init__(self, file_path="./.state-save/states.pkl") -> None: + self.file_path = file_path + self.create_dir() + self.data = self.read() + + async def convert_old_to_new(self): + # old looks like: + # {1: {'state': 'start', 'data': {'name': 'John'}} + # we should update old version pickle to new. + # new looks like: + # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}} + new_data = {} + for key, value in self.data.items(): + # this returns us id and dict with data and state + new_data[key] = {key: value} # convert this to new + # pass it to global data + self.data = new_data + self.update_data() # update data in file + + def create_dir(self): + """ + Create directory .save-handlers. + """ + dirs = self.file_path.rsplit('/', maxsplit=1)[0] + os.makedirs(dirs, exist_ok=True) + if not os.path.isfile(self.file_path): + with open(self.file_path,'wb') as file: + pickle.dump({}, file) + + def read(self): + file = open(self.file_path, 'rb') + data = pickle.load(file) + file.close() + return data + + def update_data(self): + file = open(self.file_path, 'wb+') + pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL) + file.close() + + async def set_state(self, chat_id, user_id, state): + if chat_id in self.data: + if user_id in self.data[chat_id]: + self.data[chat_id][user_id]['state'] = state + return True + else: + self.data[chat_id][user_id] = {'state': state, 'data': {}} + return True + self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + self.update_data() + return True + + async def delete_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + del self.data[chat_id][user_id] + if chat_id == user_id: + del self.data[chat_id] + self.update_data() + return True + + return False + + + async def get_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['state'] + + return None + async def get_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['data'] + + return None + + async def reset_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'] = {} + self.update_data() + return True + return False + + async def set_data(self, chat_id, user_id, key, value): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'][key] = value + self.update_data() + return True + raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + + def get_interactive_data(self, chat_id, user_id): + return StateContext(self, chat_id, user_id) + + async def save(self, chat_id, user_id, data): + self.data[chat_id][user_id]['data'] = data + self.update_data() \ No newline at end of file diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py new file mode 100644 index 0000000..f7ca0fa --- /dev/null +++ b/telebot/asyncio_storage/redis_storage.py @@ -0,0 +1,178 @@ +from pickle import FALSE +from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext +import json + +redis_installed = True +try: + import aioredis +except: + redis_installed = False + + +class StateRedisStorage(StateStorageBase): + """ + This class is for Redis storage. + This will work only for states. + To use it, just pass this class to: + TeleBot(storage=StateRedisStorage()) + """ + def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'): + if not redis_installed: + raise ImportError('AioRedis is not installed. Install it via "pip install aioredis"') + + + aioredis_version = tuple(map(int, aioredis.__version__.split(".")[0])) + if aioredis_version < (2,): + raise ImportError('Invalid aioredis version. Aioredis version should be >= 2.0.0') + self.redis = aioredis.Redis(host=host, port=port, db=db, password=password) + + self.prefix = prefix + #self.con = Redis(connection_pool=self.redis) -> use this when necessary + # + # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} + + async def get_record(self, key): + """ + Function to get record from database. + It has nothing to do with states. + Made for backend compatibility + """ + result = await self.redis.get(self.prefix+str(key)) + if result: return json.loads(result) + return + + async def set_record(self, key, value): + """ + Function to set record to database. + It has nothing to do with states. + Made for backend compatibility + """ + + await self.redis.set(self.prefix+str(key), json.dumps(value)) + return True + + async def delete_record(self, key): + """ + Function to delete record from database. + It has nothing to do with states. + Made for backend compatibility + """ + await self.redis.delete(self.prefix+str(key)) + return True + + async def set_state(self, chat_id, user_id, state): + """ + Set state for a particular user in a chat. + """ + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['state'] = state + else: + response[user_id] = {'state': state, 'data': {}} + else: + response = {user_id: {'state': state, 'data': {}}} + await self.set_record(chat_id, response) + + return True + + async def delete_state(self, chat_id, user_id): + """ + Delete state for a particular user in a chat. + """ + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + del response[user_id] + if user_id == str(chat_id): + await self.delete_record(chat_id) + return True + else: await self.set_record(chat_id, response) + return True + return False + + + async def get_value(self, chat_id, user_id, key): + """ + Get value for a data of a user in a chat. + """ + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + if key in response[user_id]['data']: + return response[user_id]['data'][key] + return None + + + async def get_state(self, chat_id, user_id): + """ + Get state of a user in a chat. + """ + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + return response[user_id]['state'] + + return None + + + async def get_data(self, chat_id, user_id): + """ + Get data of particular user in a particular chat. + """ + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + return response[user_id]['data'] + return None + + + async def reset_data(self, chat_id, user_id): + """ + Reset data of a user in a chat. + """ + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['data'] = {} + await self.set_record(chat_id, response) + return True + + + + + async def set_data(self, chat_id, user_id, key, value): + """ + Set data without interactive data. + """ + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['data'][key] = value + await self.set_record(chat_id, response) + return True + return False + + def get_interactive_data(self, chat_id, user_id): + """ + Get Data in interactive way. + You can use with() with this function. + """ + return StateContext(self, chat_id, user_id) + + async def save(self, chat_id, user_id, data): + response = await self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['data'] = dict(data, **response[user_id]['data']) + await self.set_record(chat_id, response) + return True + \ No newline at end of file diff --git a/telebot/asyncio_types.py b/telebot/asyncio_types.py new file mode 100644 index 0000000..9fe798c --- /dev/null +++ b/telebot/asyncio_types.py @@ -0,0 +1 @@ +# planned \ No newline at end of file diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index 0b95523..147596c 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -158,11 +158,21 @@ class StateFilter(AdvancedCustomFilter): key = 'state' def check(self, message, text): - if self.bot.current_states.current_state(message.from_user.id) is False: return False - elif text == '*': return True - elif type(text) is list: return self.bot.current_states.current_state(message.from_user.id) in text - return self.bot.current_states.current_state(message.from_user.id) == text - + if text == '*': return True + if message.chat.type == 'group': + group_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id) + if group_state == text: + return True + elif group_state in text and type(text) is list: + return True + + + else: + user_state = self.bot.current_states.get_state(message.chat.id,message.from_user.id) + if user_state == text: + return True + elif type(text) is list and user_state in text: + return True class IsDigitFilter(SimpleCustomFilter): """ Filter to check whether the string is made up of only digits. diff --git a/telebot/handler_backends.py b/telebot/handler_backends.py index 1e67870..df4d37f 100644 --- a/telebot/handler_backends.py +++ b/telebot/handler_backends.py @@ -3,6 +3,11 @@ import pickle import threading from telebot import apihelper +try: + from redis import Redis + redis_installed = True +except: + redis_installed = False class HandlerBackend(object): @@ -116,7 +121,8 @@ class FileHandlerBackend(HandlerBackend): class RedisHandlerBackend(HandlerBackend): def __init__(self, handlers=None, host='localhost', port=6379, db=0, prefix='telebot', password=None): super(RedisHandlerBackend, self).__init__(handlers) - from redis import Redis + if not redis_installed: + raise Exception("Redis is not installed. Install it via 'pip install redis'") self.prefix = prefix self.redis = Redis(host, port, db, password) @@ -142,198 +148,4 @@ class RedisHandlerBackend(HandlerBackend): self.clear_handlers(handler_group_id) return handlers - -class StateMemory: - def __init__(self): - self._states = {} - - def add_state(self, chat_id, state): - """ - Add a state. - :param chat_id: - :param state: new state - """ - if chat_id in self._states: - - self._states[chat_id]['state'] = state - else: - self._states[chat_id] = {'state': state,'data': {}} - - def current_state(self, chat_id): - """Current state""" - if chat_id in self._states: return self._states[chat_id]['state'] - else: return False - - def delete_state(self, chat_id): - """Delete a state""" - self._states.pop(chat_id) - - def get_data(self, chat_id): - return self._states[chat_id]['data'] - - def set(self, chat_id, new_state): - """ - Set a new state for a user. - :param chat_id: - :param new_state: new_state of a user - """ - self.add_state(chat_id,new_state) - - def add_data(self, chat_id, key, value): - result = self._states[chat_id]['data'][key] = value - return result - - def finish(self, chat_id): - """ - Finish(delete) state of a user. - :param chat_id: - """ - self.delete_state(chat_id) - - def retrieve_data(self, chat_id): - """ - Save input text. - - Usage: - with bot.retrieve_data(message.chat.id) as data: - data['name'] = message.text - - Also, at the end of your 'Form' you can get the name: - data['name'] - """ - return StateContext(self, chat_id) - - -class StateFile: - """ - Class to save states in a file. - """ - def __init__(self, filename): - self.file_path = filename - - def add_state(self, chat_id, state): - """ - Add a state. - :param chat_id: - :param state: new state - """ - states_data = self.read_data() - if chat_id in states_data: - states_data[chat_id]['state'] = state - return self.save_data(states_data) - else: - states_data[chat_id] = {'state': state,'data': {}} - return self.save_data(states_data) - - def current_state(self, chat_id): - """Current state.""" - states_data = self.read_data() - if chat_id in states_data: return states_data[chat_id]['state'] - else: return False - - def delete_state(self, chat_id): - """Delete a state""" - states_data = self.read_data() - states_data.pop(chat_id) - self.save_data(states_data) - - def read_data(self): - """ - Read the data from file. - """ - file = open(self.file_path, 'rb') - states_data = pickle.load(file) - file.close() - return states_data - - def create_dir(self): - """ - Create directory .save-handlers. - """ - dirs = self.file_path.rsplit('/', maxsplit=1)[0] - os.makedirs(dirs, exist_ok=True) - if not os.path.isfile(self.file_path): - with open(self.file_path,'wb') as file: - pickle.dump({}, file) - - def save_data(self, new_data): - """ - Save data after editing. - :param new_data: - """ - with open(self.file_path, 'wb+') as state_file: - pickle.dump(new_data, state_file, protocol=pickle.HIGHEST_PROTOCOL) - return True - - def get_data(self, chat_id): - return self.read_data()[chat_id]['data'] - - def set(self, chat_id, new_state): - """ - Set a new state for a user. - :param chat_id: - :param new_state: new_state of a user - """ - self.add_state(chat_id,new_state) - - def add_data(self, chat_id, key, value): - states_data = self.read_data() - result = states_data[chat_id]['data'][key] = value - self.save_data(result) - return result - - def finish(self, chat_id): - """ - Finish(delete) state of a user. - :param chat_id: - """ - self.delete_state(chat_id) - - def retrieve_data(self, chat_id): - """ - Save input text. - - Usage: - with bot.retrieve_data(message.chat.id) as data: - data['name'] = message.text - - Also, at the end of your 'Form' you can get the name: - data['name'] - """ - return StateFileContext(self, chat_id) - - -class StateContext: - """ - Class for data. - """ - def __init__(self , obj: StateMemory, chat_id) -> None: - self.obj = obj - self.chat_id = chat_id - self.data = obj.get_data(chat_id) - - def __enter__(self): - return self.data - - def __exit__(self, exc_type, exc_val, exc_tb): - return - - -class StateFileContext: - """ - Class for data. - """ - def __init__(self , obj: StateFile, chat_id) -> None: - self.obj = obj - self.chat_id = chat_id - self.data = self.obj.get_data(self.chat_id) - - def __enter__(self): - return self.data - - def __exit__(self, exc_type, exc_val, exc_tb): - old_data = self.obj.read_data() - for i in self.data: - old_data[self.chat_id]['data'][i] = self.data.get(i) - self.obj.save_data(old_data) - return + \ No newline at end of file diff --git a/telebot/storage/__init__.py b/telebot/storage/__init__.py new file mode 100644 index 0000000..59e2b05 --- /dev/null +++ b/telebot/storage/__init__.py @@ -0,0 +1,13 @@ +from telebot.storage.memory_storage import StateMemoryStorage +from telebot.storage.redis_storage import StateRedisStorage +from telebot.storage.pickle_storage import StatePickleStorage +from telebot.storage.base_storage import StateContext,StateStorageBase + + + + + +__all__ = [ + 'StateStorageBase', 'StateContext', + 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage' +] \ No newline at end of file diff --git a/telebot/storage/base_storage.py b/telebot/storage/base_storage.py new file mode 100644 index 0000000..2ff2b8c --- /dev/null +++ b/telebot/storage/base_storage.py @@ -0,0 +1,65 @@ +import copy + +class StateStorageBase: + def __init__(self) -> None: + pass + + def set_data(self, chat_id, user_id, key, value): + """ + Set data for a user in a particular chat. + """ + raise NotImplementedError + + def get_data(self, chat_id, user_id): + """ + Get data for a user in a particular chat. + """ + raise NotImplementedError + + def set_state(self, chat_id, user_id, state): + """ + Set state for a particular user. + + ! Note that you should create a + record if it does not exist, and + if a record with state already exists, + you need to update a record. + """ + raise NotImplementedError + + def delete_state(self, chat_id, user_id): + """ + Delete state for a particular user. + """ + raise NotImplementedError + + def reset_data(self, chat_id, user_id): + """ + Reset data for a particular user in a chat. + """ + raise NotImplementedError + + def get_state(self, chat_id, user_id): + raise NotImplementedError + + def save(chat_id, user_id, data): + raise NotImplementedError + + + +class StateContext: + """ + Class for data. + """ + def __init__(self , obj, chat_id, user_id) -> None: + self.obj = obj + self.data = copy.deepcopy(obj.get_data(chat_id, user_id)) + self.chat_id = chat_id + self.user_id = user_id + + + def __enter__(self): + return self.data + + def __exit__(self, exc_type, exc_val, exc_tb): + return self.obj.save(self.chat_id, self.user_id, self.data) \ No newline at end of file diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py new file mode 100644 index 0000000..3540ac5 --- /dev/null +++ b/telebot/storage/memory_storage.py @@ -0,0 +1,64 @@ +from telebot.storage.base_storage import StateStorageBase, StateContext + +class StateMemoryStorage(StateStorageBase): + def __init__(self) -> None: + self.data = {} + # + # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} + + + def set_state(self, chat_id, user_id, state): + if chat_id in self.data: + if user_id in self.data[chat_id]: + self.data[chat_id][user_id]['state'] = state + return True + else: + self.data[chat_id][user_id] = {'state': state, 'data': {}} + return True + self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + return True + + def delete_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + del self.data[chat_id][user_id] + if chat_id == user_id: + del self.data[chat_id] + + return True + + return False + + + def get_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['state'] + + return None + def get_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['data'] + + return None + + def reset_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'] = {} + return True + return False + + def set_data(self, chat_id, user_id, key, value): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'][key] = value + return True + raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + + def get_interactive_data(self, chat_id, user_id): + return StateContext(self, chat_id, user_id) + + def save(self, chat_id, user_id, data): + self.data[chat_id][user_id]['data'] = data \ No newline at end of file diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py new file mode 100644 index 0000000..988d454 --- /dev/null +++ b/telebot/storage/pickle_storage.py @@ -0,0 +1,112 @@ +from telebot.storage.base_storage import StateStorageBase, StateContext +import os + + +import pickle + + +class StatePickleStorage(StateStorageBase): + def __init__(self, file_path="./.state-save/states.pkl") -> None: + self.file_path = file_path + self.create_dir() + self.data = self.read() + + def convert_old_to_new(self): + """ + Use this function to convert old storage to new storage. + This function is for people who was using pickle storage + that was in version <=4.3.1. + """ + # old looks like: + # {1: {'state': 'start', 'data': {'name': 'John'}} + # we should update old version pickle to new. + # new looks like: + # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}} + new_data = {} + for key, value in self.data.items(): + # this returns us id and dict with data and state + new_data[key] = {key: value} # convert this to new + # pass it to global data + self.data = new_data + self.update_data() # update data in file + + def create_dir(self): + """ + Create directory .save-handlers. + """ + dirs = self.file_path.rsplit('/', maxsplit=1)[0] + os.makedirs(dirs, exist_ok=True) + if not os.path.isfile(self.file_path): + with open(self.file_path,'wb') as file: + pickle.dump({}, file) + + def read(self): + file = open(self.file_path, 'rb') + data = pickle.load(file) + file.close() + return data + + def update_data(self): + file = open(self.file_path, 'wb+') + pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL) + file.close() + + def set_state(self, chat_id, user_id, state): + if chat_id in self.data: + if user_id in self.data[chat_id]: + self.data[chat_id][user_id]['state'] = state + return True + else: + self.data[chat_id][user_id] = {'state': state, 'data': {}} + return True + self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + self.update_data() + return True + + def delete_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + del self.data[chat_id][user_id] + if chat_id == user_id: + del self.data[chat_id] + self.update_data() + return True + + return False + + + def get_state(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['state'] + + return None + def get_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + return self.data[chat_id][user_id]['data'] + + return None + + def reset_data(self, chat_id, user_id): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'] = {} + self.update_data() + return True + return False + + def set_data(self, chat_id, user_id, key, value): + if self.data.get(chat_id): + if self.data[chat_id].get(user_id): + self.data[chat_id][user_id]['data'][key] = value + self.update_data() + return True + raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + + def get_interactive_data(self, chat_id, user_id): + return StateContext(self, chat_id, user_id) + + def save(self, chat_id, user_id, data): + self.data[chat_id][user_id]['data'] = data + self.update_data() \ No newline at end of file diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py new file mode 100644 index 0000000..a8ba2c5 --- /dev/null +++ b/telebot/storage/redis_storage.py @@ -0,0 +1,176 @@ +from telebot.storage.base_storage import StateStorageBase, StateContext +import json + +redis_installed = True +try: + from redis import Redis, ConnectionPool + +except: + redis_installed = False + +class StateRedisStorage(StateStorageBase): + """ + This class is for Redis storage. + This will work only for states. + To use it, just pass this class to: + TeleBot(storage=StateRedisStorage()) + """ + def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'): + self.redis = ConnectionPool(host=host, port=port, db=db, password=password) + #self.con = Redis(connection_pool=self.redis) -> use this when necessary + # + # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} + self.prefix = prefix + if not redis_installed: + raise Exception("Redis is not installed. Install it via 'pip install redis'") + + def get_record(self, key): + """ + Function to get record from database. + It has nothing to do with states. + Made for backend compatibility + """ + connection = Redis(connection_pool=self.redis) + result = connection.get(self.prefix+str(key)) + connection.close() + if result: return json.loads(result) + return + + def set_record(self, key, value): + """ + Function to set record to database. + It has nothing to do with states. + Made for backend compatibility + """ + connection = Redis(connection_pool=self.redis) + connection.set(self.prefix+str(key), json.dumps(value)) + connection.close() + return True + + def delete_record(self, key): + """ + Function to delete record from database. + It has nothing to do with states. + Made for backend compatibility + """ + connection = Redis(connection_pool=self.redis) + connection.delete(self.prefix+str(key)) + connection.close() + return True + + def set_state(self, chat_id, user_id, state): + """ + Set state for a particular user in a chat. + """ + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['state'] = state + else: + response[user_id] = {'state': state, 'data': {}} + else: + response = {user_id: {'state': state, 'data': {}}} + self.set_record(chat_id, response) + + return True + + def delete_state(self, chat_id, user_id): + """ + Delete state for a particular user in a chat. + """ + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + del response[user_id] + if user_id == str(chat_id): + self.delete_record(chat_id) + return True + else: self.set_record(chat_id, response) + return True + return False + + + def get_value(self, chat_id, user_id, key): + """ + Get value for a data of a user in a chat. + """ + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + if key in response[user_id]['data']: + return response[user_id]['data'][key] + return None + + + def get_state(self, chat_id, user_id): + """ + Get state of a user in a chat. + """ + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + return response[user_id]['state'] + + return None + + + def get_data(self, chat_id, user_id): + """ + Get data of particular user in a particular chat. + """ + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + return response[user_id]['data'] + return None + + + def reset_data(self, chat_id, user_id): + """ + Reset data of a user in a chat. + """ + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['data'] = {} + self.set_record(chat_id, response) + return True + + + + + def set_data(self, chat_id, user_id, key, value): + """ + Set data without interactive data. + """ + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['data'][key] = value + self.set_record(chat_id, response) + return True + return False + + def get_interactive_data(self, chat_id, user_id): + """ + Get Data in interactive way. + You can use with() with this function. + """ + return StateContext(self, chat_id, user_id) + + def save(self, chat_id, user_id, data): + response = self.get_record(chat_id) + user_id = str(user_id) + if response: + if user_id in response: + response[user_id]['data'] = dict(data, **response[user_id]['data']) + self.set_record(chat_id, response) + return True + \ No newline at end of file