1
0
mirror of https://github.com/eternnoir/pyTelegramBotAPI.git synced 2023-08-10 21:12:57 +03:00

Merge pull request #1421 from coder2020official/master

RedisStorage, middleware fix, pass_bot parameter and more
This commit is contained in:
Badiboy 2022-01-24 21:25:16 +03:00 committed by GitHub
commit 2e9947277a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1274 additions and 562 deletions

View File

@ -1,15 +1,28 @@
import telebot import telebot
from telebot import asyncio_filters from telebot import asyncio_filters
from telebot.async_telebot import AsyncTeleBot from telebot.async_telebot import AsyncTeleBot
bot = AsyncTeleBot('TOKEN')
# list of storages, you can use any storage
from telebot.asyncio_storage import StateRedisStorage,StateMemoryStorage,StatePickleStorage
# new feature for states.
from telebot.asyncio_handler_backends import State, StatesGroup
# default state storage is statememorystorage
bot = AsyncTeleBot('TOKEN', state_storage=StateMemoryStorage())
# Just create different statesgroup
class MyStates(StatesGroup):
name = State() # statesgroup should contain states
surname = State()
age = State()
class MyStates: # set_state -> sets a new state
name = 1 # delete_state -> delets state if exists
surname = 2 # get_state -> returns state if exists
age = 3
@bot.message_handler(commands=['start']) @bot.message_handler(commands=['start'])
@ -17,7 +30,7 @@ async def start_ex(message):
""" """
Start command. Here we are starting state Start command. Here we are starting state
""" """
await bot.set_state(message.from_user.id, MyStates.name) await bot.set_state(message.from_user.id, MyStates.name, message.chat.id)
await bot.send_message(message.chat.id, 'Hi, write me a name') await bot.send_message(message.chat.id, 'Hi, write me a name')
@ -28,39 +41,45 @@ async def any_state(message):
Cancel state Cancel state
""" """
await bot.send_message(message.chat.id, "Your state was cancelled.") await bot.send_message(message.chat.id, "Your state was cancelled.")
await bot.delete_state(message.from_user.id) await bot.delete_state(message.from_user.id, message.chat.id)
@bot.message_handler(state=MyStates.name) @bot.message_handler(state=MyStates.name)
async def name_get(message): async def name_get(message):
""" """
State 1. Will process when user's state is 1. State 1. Will process when user's state is MyStates.name.
""" """
await bot.send_message(message.chat.id, f'Now write me a surname') await bot.send_message(message.chat.id, f'Now write me a surname')
await bot.set_state(message.from_user.id, MyStates.surname) await bot.set_state(message.from_user.id, MyStates.surname, message.chat.id)
async with bot.retrieve_data(message.from_user.id) as data: async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
data['name'] = message.text data['name'] = message.text
@bot.message_handler(state=MyStates.surname) @bot.message_handler(state=MyStates.surname)
async def ask_age(message): async def ask_age(message):
""" """
State 2. Will process when user's state is 2. State 2. Will process when user's state is MyStates.surname.
""" """
await bot.send_message(message.chat.id, "What is your age?") await bot.send_message(message.chat.id, "What is your age?")
await bot.set_state(message.from_user.id, MyStates.age) await bot.set_state(message.from_user.id, MyStates.age, message.chat.id)
async with bot.retrieve_data(message.from_user.id) as data: async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
data['surname'] = message.text data['surname'] = message.text
# result # result
@bot.message_handler(state=MyStates.age, is_digit=True) @bot.message_handler(state=MyStates.age, is_digit=True)
async def ready_for_answer(message): async def ready_for_answer(message):
async with bot.retrieve_data(message.from_user.id) as data: """
State 3. Will process when user's state is MyStates.age.
"""
async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
await bot.send_message(message.chat.id, "Ready, take a look:\n<b>Name: {name}\nSurname: {surname}\nAge: {age}</b>".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html") await bot.send_message(message.chat.id, "Ready, take a look:\n<b>Name: {name}\nSurname: {surname}\nAge: {age}</b>".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html")
await bot.delete_state(message.from_user.id) await bot.delete_state(message.from_user.id, message.chat.id)
#incorrect number #incorrect number
@bot.message_handler(state=MyStates.age, is_digit=False) @bot.message_handler(state=MyStates.age, is_digit=False)
async def age_incorrect(message): async def age_incorrect(message):
"""
Will process for wrong input when state is MyState.age
"""
await bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number') await bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number')
# register filters # register filters
@ -68,8 +87,6 @@ async def age_incorrect(message):
bot.add_custom_filter(asyncio_filters.StateFilter(bot)) bot.add_custom_filter(asyncio_filters.StateFilter(bot))
bot.add_custom_filter(asyncio_filters.IsDigitFilter()) bot.add_custom_filter(asyncio_filters.IsDigitFilter())
# set saving states into file.
bot.enable_saving_states() # you can delete this if you do not need to save states
import asyncio import asyncio
asyncio.run(bot.polling()) asyncio.run(bot.polling())

View File

@ -1,14 +1,39 @@
import telebot import telebot # telebot
from telebot import custom_filters from telebot import custom_filters
from telebot.handler_backends import State, StatesGroup #States
bot = telebot.TeleBot("") # States storage
from telebot.storage import StateRedisStorage, StatePickleStorage, StateMemoryStorage
class MyStates: # Beginning from version 4.4.0+, we support storages.
name = 1 # StateRedisStorage -> Redis-based storage.
surname = 2 # StatePickleStorage -> Pickle-based storage.
age = 3 # For redis, you will need to install redis.
# Pass host, db, password, or anything else,
# if you need to change config for redis.
# Pickle requires path. Default path is in folder .state-saves.
# If you were using older version of pytba for pickle,
# you need to migrate from old pickle to new by using
# StatePickleStorage().convert_old_to_new()
# Now, you can pass storage to bot.
state_storage = StateMemoryStorage() # you can init here another storage
bot = telebot.TeleBot("TOKEN",
state_storage=state_storage)
# States group.
class MyStates(StatesGroup):
# Just name variables differently
name = State() # creating instances of State class is enough from now
surname = State()
age = State()
@ -17,50 +42,56 @@ def start_ex(message):
""" """
Start command. Here we are starting state Start command. Here we are starting state
""" """
bot.set_state(message.from_user.id, MyStates.name) bot.set_state(message.from_user.id, MyStates.name, message.chat.id)
bot.send_message(message.chat.id, 'Hi, write me a name') bot.send_message(message.chat.id, 'Hi, write me a name')
# Any state
@bot.message_handler(state="*", commands='cancel') @bot.message_handler(state="*", commands='cancel')
def any_state(message): def any_state(message):
""" """
Cancel state Cancel state
""" """
bot.send_message(message.chat.id, "Your state was cancelled.") bot.send_message(message.chat.id, "Your state was cancelled.")
bot.delete_state(message.from_user.id) bot.delete_state(message.from_user.id, message.chat.id)
@bot.message_handler(state=MyStates.name) @bot.message_handler(state=MyStates.name)
def name_get(message): def name_get(message):
""" """
State 1. Will process when user's state is 1. State 1. Will process when user's state is MyStates.name.
""" """
bot.send_message(message.chat.id, f'Now write me a surname') bot.send_message(message.chat.id, f'Now write me a surname')
bot.set_state(message.from_user.id, MyStates.surname) bot.set_state(message.from_user.id, MyStates.surname, message.chat.id)
with bot.retrieve_data(message.from_user.id) as data: with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
data['name'] = message.text data['name'] = message.text
@bot.message_handler(state=MyStates.surname) @bot.message_handler(state=MyStates.surname)
def ask_age(message): def ask_age(message):
""" """
State 2. Will process when user's state is 2. State 2. Will process when user's state is MyStates.surname.
""" """
bot.send_message(message.chat.id, "What is your age?") bot.send_message(message.chat.id, "What is your age?")
bot.set_state(message.from_user.id, MyStates.age) bot.set_state(message.from_user.id, MyStates.age, message.chat.id)
with bot.retrieve_data(message.from_user.id) as data: with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
data['surname'] = message.text data['surname'] = message.text
# result # result
@bot.message_handler(state=MyStates.age, is_digit=True) @bot.message_handler(state=MyStates.age, is_digit=True)
def ready_for_answer(message): def ready_for_answer(message):
with bot.retrieve_data(message.from_user.id) as data: """
State 3. Will process when user's state is MyStates.age.
"""
with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
bot.send_message(message.chat.id, "Ready, take a look:\n<b>Name: {name}\nSurname: {surname}\nAge: {age}</b>".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html") bot.send_message(message.chat.id, "Ready, take a look:\n<b>Name: {name}\nSurname: {surname}\nAge: {age}</b>".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html")
bot.delete_state(message.from_user.id) bot.delete_state(message.from_user.id, message.chat.id)
#incorrect number #incorrect number
@bot.message_handler(state=MyStates.age, is_digit=False) @bot.message_handler(state=MyStates.age, is_digit=False)
def age_incorrect(message): def age_incorrect(message):
"""
Wrong response for MyStates.age
"""
bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number') bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number')
# register filters # register filters
@ -68,7 +99,4 @@ def age_incorrect(message):
bot.add_custom_filter(custom_filters.StateFilter(bot)) bot.add_custom_filter(custom_filters.StateFilter(bot))
bot.add_custom_filter(custom_filters.IsDigitFilter()) bot.add_custom_filter(custom_filters.IsDigitFilter())
# set saving states into file.
bot.enable_saving_states() # you can delete this if you do not need to save states
bot.infinity_polling(skip_pending=True) bot.infinity_polling(skip_pending=True)

View File

@ -13,7 +13,8 @@ from typing import Any, Callable, List, Optional, Union
import telebot.util import telebot.util
import telebot.types import telebot.types
# storage
from telebot.storage import StatePickleStorage, StateMemoryStorage
logger = logging.getLogger('TeleBot') logger = logging.getLogger('TeleBot')
@ -28,7 +29,7 @@ logger.addHandler(console_output_handler)
logger.setLevel(logging.ERROR) logger.setLevel(logging.ERROR)
from telebot import apihelper, util, types from telebot import apihelper, util, types
from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend, StateMemory, StateFile from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend
from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter
@ -148,7 +149,7 @@ class TeleBot:
def __init__( def __init__(
self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2, self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2,
next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0, next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0,
suppress_middleware_excepions=False suppress_middleware_excepions=False, state_storage=StateMemoryStorage()
): ):
""" """
:param token: bot API token :param token: bot API token
@ -193,7 +194,7 @@ class TeleBot:
self.custom_filters = {} self.custom_filters = {}
self.state_handlers = [] self.state_handlers = []
self.current_states = StateMemory() self.current_states = state_storage
if apihelper.ENABLE_MIDDLEWARE: if apihelper.ENABLE_MIDDLEWARE:
@ -251,7 +252,7 @@ class TeleBot:
:param filename: Filename of saving file :param filename: Filename of saving file
""" """
self.current_states = StateFile(filename=filename) self.current_states = StatePickleStorage(filename=filename)
self.current_states.create_dir() self.current_states.create_dir()
def enable_save_reply_handlers(self, delay=120, filename="./.handler-saves/reply.save"): def enable_save_reply_handlers(self, delay=120, filename="./.handler-saves/reply.save"):
@ -777,6 +778,13 @@ class TeleBot:
logger.info('Stopped polling.') logger.info('Stopped polling.')
def _exec_task(self, task, *args, **kwargs): def _exec_task(self, task, *args, **kwargs):
if kwargs.get('task_type') == 'handler':
pass_bot = kwargs.get('pass_bot')
kwargs.pop('pass_bot')
kwargs.pop('task_type')
if pass_bot:
kwargs['bot'] = self
if self.threaded: if self.threaded:
self.worker_pool.put(task, *args, **kwargs) self.worker_pool.put(task, *args, **kwargs)
else: else:
@ -2531,40 +2539,59 @@ class TeleBot:
chat_id = message.chat.id chat_id = message.chat.id
self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs) self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs)
def set_state(self, chat_id: int, state: Union[int, str]): def set_state(self, user_id: int, state: Union[int, str], chat_id: int=None) -> None:
""" """
Sets a new state of a user. Sets a new state of a user.
:param chat_id: :param chat_id:
:param state: new state. can be string or integer. :param state: new state. can be string or integer.
""" """
self.current_states.add_state(chat_id, state) if chat_id is None:
chat_id = user_id
self.current_states.set_state(chat_id, user_id, state)
def delete_state(self, chat_id: int): def reset_data(self, user_id: int, chat_id: int=None):
"""
Reset data for a user in chat.
:param user_id:
:param chat_id:
"""
if chat_id is None:
chat_id = user_id
self.current_states.reset_data(chat_id, user_id)
def delete_state(self, user_id: int, chat_id: int=None) -> None:
""" """
Delete the current state of a user. Delete the current state of a user.
:param chat_id: :param chat_id:
:return: :return:
""" """
self.current_states.delete_state(chat_id) if chat_id is None:
chat_id = user_id
self.current_states.delete_state(chat_id, user_id)
def retrieve_data(self, chat_id: int): def retrieve_data(self, user_id: int, chat_id: int=None) -> Optional[Union[int, str]]:
return self.current_states.retrieve_data(chat_id) if chat_id is None:
chat_id = user_id
return self.current_states.get_interactive_data(chat_id, user_id)
def get_state(self, chat_id: int): def get_state(self, user_id: int, chat_id: int=None) -> Optional[Union[int, str]]:
""" """
Get current state of a user. Get current state of a user.
:param chat_id: :param chat_id:
:return: state of a user :return: state of a user
""" """
return self.current_states.current_state(chat_id) if chat_id is None:
chat_id = user_id
return self.current_states.get_state(chat_id, user_id)
def add_data(self, chat_id: int, **kwargs): def add_data(self, user_id: int, chat_id:int=None, **kwargs):
""" """
Add data to states. Add data to states.
:param chat_id: :param chat_id:
""" """
if chat_id is None:
chat_id = user_id
for key, value in kwargs.items(): for key, value in kwargs.items():
self.current_states.add_data(chat_id, key, value) self.current_states.set_data(chat_id, user_id, key, value)
def register_next_step_handler_by_chat_id( def register_next_step_handler_by_chat_id(
self, chat_id: Union[int, str], callback: Callable, *args, **kwargs) -> None: self, chat_id: Union[int, str], callback: Callable, *args, **kwargs) -> None:
@ -2632,7 +2659,7 @@ class TeleBot:
@staticmethod @staticmethod
def _build_handler_dict(handler, **filters): def _build_handler_dict(handler, pass_bot=False, **filters):
""" """
Builds a dictionary for a handler Builds a dictionary for a handler
:param handler: :param handler:
@ -2641,6 +2668,7 @@ class TeleBot:
""" """
return { return {
'function': handler, 'function': handler,
'pass_bot': pass_bot,
'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None} 'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None}
# Remove None values, they are skipped in _test_filter anyway # Remove None values, they are skipped in _test_filter anyway
#'filters': filters #'filters': filters
@ -2686,7 +2714,7 @@ class TeleBot:
:return: :return:
""" """
if not apihelper.ENABLE_MIDDLEWARE: if not apihelper.ENABLE_MIDDLEWARE:
raise RuntimeError("Middleware is not enabled. Use apihelper.ENABLE_MIDDLEWARE.") raise RuntimeError("Middleware is not enabled. Use apihelper.ENABLE_MIDDLEWARE before initialising TeleBot.")
if update_types: if update_types:
for update_type in update_types: for update_type in update_types:
@ -2694,6 +2722,27 @@ class TeleBot:
else: else:
self.default_middleware_handlers.append(handler) self.default_middleware_handlers.append(handler)
# function register_middleware_handler
def register_middleware_handler(self, callback, update_types=None):
"""
Middleware handler decorator.
This function will create a decorator that can be used to decorate functions that must be handled as middlewares before entering any other
message handlers
But, be careful and check type of the update inside the handler if more than one update_type is given
Example:
bot = TeleBot('TOKEN')
bot.register_middleware_handler(print_channel_post_text, update_types=['channel_post', 'edited_channel_post'])
:param update_types: Optional list of update types that can be passed into the middleware handler.
"""
self.add_middleware_handler(callback, update_types)
def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs): def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs):
""" """
Message handler decorator. Message handler decorator.
@ -2766,7 +2815,7 @@ class TeleBot:
""" """
self.message_handlers.append(handler_dict) self.message_handlers.append(handler_dict)
def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
""" """
Registers message handler. Registers message handler.
:param callback: function to be called :param callback: function to be called
@ -2775,6 +2824,7 @@ class TeleBot:
:param regexp: :param regexp:
:param func: :param func:
:param chat_types: True for private chat :param chat_types: True for private chat
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
if isinstance(commands, str): if isinstance(commands, str):
@ -2791,6 +2841,7 @@ class TeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_message_handler(handler_dict) self.add_message_handler(handler_dict)
@ -2838,7 +2889,7 @@ class TeleBot:
""" """
self.edited_message_handlers.append(handler_dict) self.edited_message_handlers.append(handler_dict)
def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
""" """
Registers edited message handler. Registers edited message handler.
:param callback: function to be called :param callback: function to be called
@ -2847,6 +2898,7 @@ class TeleBot:
:param regexp: :param regexp:
:param func: :param func:
:param chat_types: True for private chat :param chat_types: True for private chat
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
if isinstance(commands, str): if isinstance(commands, str):
@ -2863,6 +2915,7 @@ class TeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_edited_message_handler(handler_dict) self.add_edited_message_handler(handler_dict)
@ -2908,7 +2961,7 @@ class TeleBot:
""" """
self.channel_post_handlers.append(handler_dict) self.channel_post_handlers.append(handler_dict)
def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
""" """
Registers channel post message handler. Registers channel post message handler.
:param callback: function to be called :param callback: function to be called
@ -2916,6 +2969,7 @@ class TeleBot:
:param commands: list of commands :param commands: list of commands
:param regexp: :param regexp:
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
if isinstance(commands, str): if isinstance(commands, str):
@ -2931,6 +2985,7 @@ class TeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_channel_post_handler(handler_dict) self.add_channel_post_handler(handler_dict)
@ -2975,7 +3030,7 @@ class TeleBot:
""" """
self.edited_channel_post_handlers.append(handler_dict) self.edited_channel_post_handlers.append(handler_dict)
def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
""" """
Registers edited channel post message handler. Registers edited channel post message handler.
:param callback: function to be called :param callback: function to be called
@ -2983,6 +3038,7 @@ class TeleBot:
:param commands: list of commands :param commands: list of commands
:param regexp: :param regexp:
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
if isinstance(commands, str): if isinstance(commands, str):
@ -2998,6 +3054,7 @@ class TeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_edited_channel_post_handler(handler_dict) self.add_edited_channel_post_handler(handler_dict)
@ -3024,14 +3081,15 @@ class TeleBot:
""" """
self.inline_handlers.append(handler_dict) self.inline_handlers.append(handler_dict)
def register_inline_handler(self, callback, func, **kwargs): def register_inline_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers inline handler. Registers inline handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_inline_handler(handler_dict) self.add_inline_handler(handler_dict)
def chosen_inline_handler(self, func, **kwargs): def chosen_inline_handler(self, func, **kwargs):
@ -3057,14 +3115,15 @@ class TeleBot:
""" """
self.chosen_inline_handlers.append(handler_dict) self.chosen_inline_handlers.append(handler_dict)
def register_chosen_inline_handler(self, callback, func, **kwargs): def register_chosen_inline_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers chosen inline handler. Registers chosen inline handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_chosen_inline_handler(handler_dict) self.add_chosen_inline_handler(handler_dict)
def callback_query_handler(self, func, **kwargs): def callback_query_handler(self, func, **kwargs):
@ -3090,14 +3149,15 @@ class TeleBot:
""" """
self.callback_query_handlers.append(handler_dict) self.callback_query_handlers.append(handler_dict)
def register_callback_query_handler(self, callback, func, **kwargs): def register_callback_query_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers callback query handler.. Registers callback query handler..
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_callback_query_handler(handler_dict) self.add_callback_query_handler(handler_dict)
def shipping_query_handler(self, func, **kwargs): def shipping_query_handler(self, func, **kwargs):
@ -3123,14 +3183,15 @@ class TeleBot:
""" """
self.shipping_query_handlers.append(handler_dict) self.shipping_query_handlers.append(handler_dict)
def register_shipping_query_handler(self, callback, func, **kwargs): def register_shipping_query_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers shipping query handler. Registers shipping query handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_shipping_query_handler(handler_dict) self.add_shipping_query_handler(handler_dict)
def pre_checkout_query_handler(self, func, **kwargs): def pre_checkout_query_handler(self, func, **kwargs):
@ -3156,14 +3217,15 @@ class TeleBot:
""" """
self.pre_checkout_query_handlers.append(handler_dict) self.pre_checkout_query_handlers.append(handler_dict)
def register_pre_checkout_query_handler(self, callback, func, **kwargs): def register_pre_checkout_query_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers pre-checkout request handler. Registers pre-checkout request handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_pre_checkout_query_handler(handler_dict) self.add_pre_checkout_query_handler(handler_dict)
def poll_handler(self, func, **kwargs): def poll_handler(self, func, **kwargs):
@ -3189,14 +3251,15 @@ class TeleBot:
""" """
self.poll_handlers.append(handler_dict) self.poll_handlers.append(handler_dict)
def register_poll_handler(self, callback, func, **kwargs): def register_poll_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers poll handler. Registers poll handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_poll_handler(handler_dict) self.add_poll_handler(handler_dict)
def poll_answer_handler(self, func=None, **kwargs): def poll_answer_handler(self, func=None, **kwargs):
@ -3222,14 +3285,15 @@ class TeleBot:
""" """
self.poll_answer_handlers.append(handler_dict) self.poll_answer_handlers.append(handler_dict)
def register_poll_answer_handler(self, callback, func, **kwargs): def register_poll_answer_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers poll answer handler. Registers poll answer handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_poll_answer_handler(handler_dict) self.add_poll_answer_handler(handler_dict)
def my_chat_member_handler(self, func=None, **kwargs): def my_chat_member_handler(self, func=None, **kwargs):
@ -3255,14 +3319,15 @@ class TeleBot:
""" """
self.my_chat_member_handlers.append(handler_dict) self.my_chat_member_handlers.append(handler_dict)
def register_my_chat_member_handler(self, callback, func=None, **kwargs): def register_my_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
""" """
Registers my chat member handler. Registers my chat member handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_my_chat_member_handler(handler_dict) self.add_my_chat_member_handler(handler_dict)
def chat_member_handler(self, func=None, **kwargs): def chat_member_handler(self, func=None, **kwargs):
@ -3288,14 +3353,15 @@ class TeleBot:
""" """
self.chat_member_handlers.append(handler_dict) self.chat_member_handlers.append(handler_dict)
def register_chat_member_handler(self, callback, func=None, **kwargs): def register_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
""" """
Registers chat member handler. Registers chat member handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_chat_member_handler(handler_dict) self.add_chat_member_handler(handler_dict)
def chat_join_request_handler(self, func=None, **kwargs): def chat_join_request_handler(self, func=None, **kwargs):
@ -3321,14 +3387,15 @@ class TeleBot:
""" """
self.chat_join_request_handlers.append(handler_dict) self.chat_join_request_handlers.append(handler_dict)
def register_chat_join_request_handler(self, callback, func=None, **kwargs): def register_chat_join_request_handler(self, callback, func=None, pass_bot=False, **kwargs):
""" """
Registers chat join request handler. Registers chat join request handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:param pass_bot: Pass TeleBot to handler.
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_chat_join_request_handler(handler_dict) self.add_chat_join_request_handler(handler_dict)
def _test_message_handler(self, message_handler, message): def _test_message_handler(self, message_handler, message):
@ -3409,7 +3476,7 @@ class TeleBot:
for message in new_messages: for message in new_messages:
for message_handler in handlers: for message_handler in handlers:
if self._test_message_handler(message_handler, message): if self._test_message_handler(message_handler, message):
self._exec_task(message_handler['function'], message) self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler')
break break

View File

@ -13,6 +13,9 @@ import telebot.util
import telebot.types import telebot.types
# storages
from telebot.asyncio_storage import StateMemoryStorage, StatePickleStorage
from inspect import signature from inspect import signature
from telebot import logger from telebot import logger
@ -161,7 +164,7 @@ class AsyncTeleBot:
""" """
def __init__(self, token: str, parse_mode: Optional[str]=None, offset=None, def __init__(self, token: str, parse_mode: Optional[str]=None, offset=None,
exception_handler=None) -> None: # TODO: ADD TYPEHINTS exception_handler=None, state_storage=StateMemoryStorage()) -> None: # TODO: ADD TYPEHINTS
self.token = token self.token = token
self.offset = offset self.offset = offset
@ -190,12 +193,13 @@ class AsyncTeleBot:
self.custom_filters = {} self.custom_filters = {}
self.state_handlers = [] self.state_handlers = []
self.current_states = asyncio_handler_backends.StateMemory() self.current_states = state_storage
self.middlewares = [] self.middlewares = []
async def close_session(self):
await asyncio_helper.session_manager.session.close()
async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None, async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None,
timeout: Optional[int]=None, allowed_updates: Optional[List]=None, request_timeout: Optional[int]=None) -> List[types.Update]: timeout: Optional[int]=None, allowed_updates: Optional[List]=None, request_timeout: Optional[int]=None) -> List[types.Update]:
json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates, request_timeout) json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates, request_timeout)
@ -299,7 +303,7 @@ class AsyncTeleBot:
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout) updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout)
if updates: if updates:
self.offset = updates[-1].update_id + 1 self.offset = updates[-1].update_id + 1
self._loop_create_task(self.process_new_updates(updates)) # Seperate task for processing updates asyncio.create_task(self.process_new_updates(updates)) # Seperate task for processing updates
if interval: await asyncio.sleep(interval) if interval: await asyncio.sleep(interval)
except KeyboardInterrupt: except KeyboardInterrupt:
@ -322,6 +326,8 @@ class AsyncTeleBot:
continue continue
else: else:
break break
except KeyboardInterrupt:
return
except Exception as e: except Exception as e:
logger.error('Cause exception while getting updates.') logger.error('Cause exception while getting updates.')
if non_stop: if non_stop:
@ -333,6 +339,7 @@ class AsyncTeleBot:
finally: finally:
self._polling = False self._polling = False
await self.close_session()
logger.warning('Polling is stopped.') logger.warning('Polling is stopped.')
@ -346,31 +353,48 @@ class AsyncTeleBot:
:param messages: :param messages:
:return: :return:
""" """
tasks = []
for message in messages: for message in messages:
middleware = await self.process_middlewares(message, update_type) middleware = await self.process_middlewares(message, update_type)
self._loop_create_task(self._run_middlewares_and_handlers(handlers, message, middleware)) tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware))
asyncio.gather(*tasks)
async def _run_middlewares_and_handlers(self, handlers, message, middleware): async def _run_middlewares_and_handlers(self, handlers, message, middleware):
handler_error = None handler_error = None
data = {} data = {}
for message_handler in handlers: process_handler = True
process_update = await self._test_message_handler(message_handler, message)
if not process_update:
continue
elif process_update:
if middleware: if middleware:
middleware_result = await middleware.pre_process(message, data) middleware_result = await middleware.pre_process(message, data)
if isinstance(middleware_result, SkipHandler): if isinstance(middleware_result, SkipHandler):
await middleware.post_process(message, data, handler_error) await middleware.post_process(message, data, handler_error)
break process_handler = False
if isinstance(middleware_result, CancelUpdate): if isinstance(middleware_result, CancelUpdate):
return return
for handler in handlers:
if not process_handler:
break
process_update = await self._test_message_handler(handler, message)
if not process_update:
continue
elif process_update:
try: try:
if "data" in signature(message_handler['function']).parameters: params = []
await message_handler['function'](message, data)
else: for i in signature(handler['function']).parameters:
await message_handler['function'](message) params.append(i)
if len(params) == 1:
await handler['function'](message)
break
if params[1] == 'data' and handler.get('pass_bot') is True:
await handler['function'](message, data, self)
break
elif params[1] == 'data' and handler.get('pass_bot') is False:
await handler['function'](message, data)
break
elif params[1] != 'data' and handler.get('pass_bot') is True:
await handler['function'](message, self)
break break
except Exception as e: except Exception as e:
handler_error = e handler_error = e
@ -687,7 +711,7 @@ class AsyncTeleBot:
""" """
self.message_handlers.append(handler_dict) self.message_handlers.append(handler_dict)
def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
""" """
Registers message handler. Registers message handler.
:param callback: function to be called :param callback: function to be called
@ -696,8 +720,11 @@ class AsyncTeleBot:
:param regexp: :param regexp:
:param func: :param func:
:param chat_types: True for private chat :param chat_types: True for private chat
:param pass_bot: True if you want to get TeleBot instance in your handler
:return: decorated function :return: decorated function
""" """
if content_types is None:
content_types = ["text"]
if isinstance(commands, str): if isinstance(commands, str):
logger.warning("register_message_handler: 'commands' filter should be List of strings (commands), not string.") logger.warning("register_message_handler: 'commands' filter should be List of strings (commands), not string.")
commands = [commands] commands = [commands]
@ -712,6 +739,7 @@ class AsyncTeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_message_handler(handler_dict) self.add_message_handler(handler_dict)
@ -759,7 +787,7 @@ class AsyncTeleBot:
""" """
self.edited_message_handlers.append(handler_dict) self.edited_message_handlers.append(handler_dict)
def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs): def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
""" """
Registers edited message handler. Registers edited message handler.
:param callback: function to be called :param callback: function to be called
@ -784,6 +812,7 @@ class AsyncTeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_edited_message_handler(handler_dict) self.add_edited_message_handler(handler_dict)
@ -829,7 +858,7 @@ class AsyncTeleBot:
""" """
self.channel_post_handlers.append(handler_dict) self.channel_post_handlers.append(handler_dict)
def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
""" """
Registers channel post message handler. Registers channel post message handler.
:param callback: function to be called :param callback: function to be called
@ -852,6 +881,7 @@ class AsyncTeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_channel_post_handler(handler_dict) self.add_channel_post_handler(handler_dict)
@ -896,7 +926,7 @@ class AsyncTeleBot:
""" """
self.edited_channel_post_handlers.append(handler_dict) self.edited_channel_post_handlers.append(handler_dict)
def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs): def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
""" """
Registers edited channel post message handler. Registers edited channel post message handler.
:param callback: function to be called :param callback: function to be called
@ -919,6 +949,7 @@ class AsyncTeleBot:
commands=commands, commands=commands,
regexp=regexp, regexp=regexp,
func=func, func=func,
pass_bot=pass_bot,
**kwargs) **kwargs)
self.add_edited_channel_post_handler(handler_dict) self.add_edited_channel_post_handler(handler_dict)
@ -945,14 +976,14 @@ class AsyncTeleBot:
""" """
self.inline_handlers.append(handler_dict) self.inline_handlers.append(handler_dict)
def register_inline_handler(self, callback, func, **kwargs): def register_inline_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers inline handler. Registers inline handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_inline_handler(handler_dict) self.add_inline_handler(handler_dict)
def chosen_inline_handler(self, func, **kwargs): def chosen_inline_handler(self, func, **kwargs):
@ -978,14 +1009,14 @@ class AsyncTeleBot:
""" """
self.chosen_inline_handlers.append(handler_dict) self.chosen_inline_handlers.append(handler_dict)
def register_chosen_inline_handler(self, callback, func, **kwargs): def register_chosen_inline_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers chosen inline handler. Registers chosen inline handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_chosen_inline_handler(handler_dict) self.add_chosen_inline_handler(handler_dict)
def callback_query_handler(self, func, **kwargs): def callback_query_handler(self, func, **kwargs):
@ -1011,14 +1042,14 @@ class AsyncTeleBot:
""" """
self.callback_query_handlers.append(handler_dict) self.callback_query_handlers.append(handler_dict)
def register_callback_query_handler(self, callback, func, **kwargs): def register_callback_query_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers callback query handler.. Registers callback query handler..
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_callback_query_handler(handler_dict) self.add_callback_query_handler(handler_dict)
def shipping_query_handler(self, func, **kwargs): def shipping_query_handler(self, func, **kwargs):
@ -1044,14 +1075,14 @@ class AsyncTeleBot:
""" """
self.shipping_query_handlers.append(handler_dict) self.shipping_query_handlers.append(handler_dict)
def register_shipping_query_handler(self, callback, func, **kwargs): def register_shipping_query_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers shipping query handler. Registers shipping query handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_shipping_query_handler(handler_dict) self.add_shipping_query_handler(handler_dict)
def pre_checkout_query_handler(self, func, **kwargs): def pre_checkout_query_handler(self, func, **kwargs):
@ -1077,14 +1108,14 @@ class AsyncTeleBot:
""" """
self.pre_checkout_query_handlers.append(handler_dict) self.pre_checkout_query_handlers.append(handler_dict)
def register_pre_checkout_query_handler(self, callback, func, **kwargs): def register_pre_checkout_query_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers pre-checkout request handler. Registers pre-checkout request handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_pre_checkout_query_handler(handler_dict) self.add_pre_checkout_query_handler(handler_dict)
def poll_handler(self, func, **kwargs): def poll_handler(self, func, **kwargs):
@ -1110,14 +1141,14 @@ class AsyncTeleBot:
""" """
self.poll_handlers.append(handler_dict) self.poll_handlers.append(handler_dict)
def register_poll_handler(self, callback, func, **kwargs): def register_poll_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers poll handler. Registers poll handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_poll_handler(handler_dict) self.add_poll_handler(handler_dict)
def poll_answer_handler(self, func=None, **kwargs): def poll_answer_handler(self, func=None, **kwargs):
@ -1143,14 +1174,14 @@ class AsyncTeleBot:
""" """
self.poll_answer_handlers.append(handler_dict) self.poll_answer_handlers.append(handler_dict)
def register_poll_answer_handler(self, callback, func, **kwargs): def register_poll_answer_handler(self, callback, func, pass_bot=False, **kwargs):
""" """
Registers poll answer handler. Registers poll answer handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_poll_answer_handler(handler_dict) self.add_poll_answer_handler(handler_dict)
def my_chat_member_handler(self, func=None, **kwargs): def my_chat_member_handler(self, func=None, **kwargs):
@ -1176,14 +1207,14 @@ class AsyncTeleBot:
""" """
self.my_chat_member_handlers.append(handler_dict) self.my_chat_member_handlers.append(handler_dict)
def register_my_chat_member_handler(self, callback, func=None, **kwargs): def register_my_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
""" """
Registers my chat member handler. Registers my chat member handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_my_chat_member_handler(handler_dict) self.add_my_chat_member_handler(handler_dict)
def chat_member_handler(self, func=None, **kwargs): def chat_member_handler(self, func=None, **kwargs):
@ -1209,14 +1240,14 @@ class AsyncTeleBot:
""" """
self.chat_member_handlers.append(handler_dict) self.chat_member_handlers.append(handler_dict)
def register_chat_member_handler(self, callback, func=None, **kwargs): def register_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
""" """
Registers chat member handler. Registers chat member handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_chat_member_handler(handler_dict) self.add_chat_member_handler(handler_dict)
def chat_join_request_handler(self, func=None, **kwargs): def chat_join_request_handler(self, func=None, **kwargs):
@ -1242,18 +1273,18 @@ class AsyncTeleBot:
""" """
self.chat_join_request_handlers.append(handler_dict) self.chat_join_request_handlers.append(handler_dict)
def register_chat_join_request_handler(self, callback, func=None, **kwargs): def register_chat_join_request_handler(self, callback, func=None, pass_bot=False, **kwargs):
""" """
Registers chat join request handler. Registers chat join request handler.
:param callback: function to be called :param callback: function to be called
:param func: :param func:
:return: decorated function :return: decorated function
""" """
handler_dict = self._build_handler_dict(callback, func=func, **kwargs) handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
self.add_chat_join_request_handler(handler_dict) self.add_chat_join_request_handler(handler_dict)
@staticmethod @staticmethod
def _build_handler_dict(handler, **filters): def _build_handler_dict(handler, pass_bot=False, **filters):
""" """
Builds a dictionary for a handler Builds a dictionary for a handler
:param handler: :param handler:
@ -1262,6 +1293,7 @@ class AsyncTeleBot:
""" """
return { return {
'function': handler, 'function': handler,
'pass_bot': pass_bot,
'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None} 'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None}
# Remove None values, they are skipped in _test_filter anyway # Remove None values, they are skipped in _test_filter anyway
#'filters': filters #'filters': filters
@ -1324,8 +1356,7 @@ class AsyncTeleBot:
:param filename: Filename of saving file :param filename: Filename of saving file
""" """
self.current_states = asyncio_handler_backends.StateFile(filename=filename) self.current_states = StatePickleStorage(file_path=filename)
self.current_states.create_dir()
async def set_webhook(self, url=None, certificate=None, max_connections=None, allowed_updates=None, ip_address=None, async def set_webhook(self, url=None, certificate=None, max_connections=None, allowed_updates=None, ip_address=None,
drop_pending_updates = None, timeout=None): drop_pending_updates = None, timeout=None):
@ -1356,6 +1387,8 @@ class AsyncTeleBot:
return await asyncio_helper.set_webhook(self.token, url, certificate, max_connections, allowed_updates, ip_address, return await asyncio_helper.set_webhook(self.token, url, certificate, max_connections, allowed_updates, ip_address,
drop_pending_updates, timeout) drop_pending_updates, timeout)
async def delete_webhook(self, drop_pending_updates=None, timeout=None): async def delete_webhook(self, drop_pending_updates=None, timeout=None):
""" """
Use this method to remove webhook integration if you decide to switch back to getUpdates. Use this method to remove webhook integration if you decide to switch back to getUpdates.
@ -1366,6 +1399,12 @@ class AsyncTeleBot:
""" """
return await asyncio_helper.delete_webhook(self.token, drop_pending_updates, timeout) return await asyncio_helper.delete_webhook(self.token, drop_pending_updates, timeout)
async def remove_webhook(self):
"""
Alternative for delete_webhook but uses set_webhook
"""
self.set_webhook()
async def get_webhook_info(self, timeout=None): async def get_webhook_info(self, timeout=None):
""" """
Use this method to get current webhook status. Requires no parameters. Use this method to get current webhook status. Requires no parameters.
@ -2443,8 +2482,8 @@ class AsyncTeleBot:
""" """
return await asyncio_helper.delete_chat_photo(self.token, chat_id) return await asyncio_helper.delete_chat_photo(self.token, chat_id)
async def get_my_commands(self, scope: Optional[types.BotCommandScope]=None, async def get_my_commands(self, scope: Optional[types.BotCommandScope],
language_code: Optional[str]=None) -> List[types.BotCommand]: language_code: Optional[str]) -> List[types.BotCommand]:
""" """
Use this method to get the current list of the bot's commands. Use this method to get the current list of the bot's commands.
Returns List of BotCommand on success. Returns List of BotCommand on success.
@ -3019,37 +3058,57 @@ class AsyncTeleBot:
return await asyncio_helper.delete_sticker_from_set(self.token, sticker) return await asyncio_helper.delete_sticker_from_set(self.token, sticker)
async def set_state(self, chat_id, state): async def set_state(self, user_id: int, state: str, chat_id: int=None):
""" """
Sets a new state of a user. Sets a new state of a user.
:param chat_id: :param chat_id:
:param state: new state. can be string or integer. :param state: new state. can be string or integer.
""" """
await self.current_states.add_state(chat_id, state) if not chat_id:
chat_id = user_id
await self.current_states.set_state(chat_id, user_id, state)
async def delete_state(self, chat_id): async def reset_data(self, user_id: int, chat_id: int=None):
"""
Reset data for a user in chat.
:param user_id:
:param chat_id:
"""
if chat_id is None:
chat_id = user_id
await self.current_states.reset_data(chat_id, user_id)
async def delete_state(self, user_id: int, chat_id:int=None):
""" """
Delete the current state of a user. Delete the current state of a user.
:param chat_id: :param chat_id:
:return: :return:
""" """
await self.current_states.delete_state(chat_id) if not chat_id:
chat_id = user_id
await self.current_states.delete_state(chat_id, user_id)
def retrieve_data(self, chat_id): def retrieve_data(self, user_id: int, chat_id: int=None):
return self.current_states.retrieve_data(chat_id) if not chat_id:
chat_id = user_id
return self.current_states.get_interactive_data(chat_id, user_id)
async def get_state(self, chat_id): async def get_state(self, user_id, chat_id: int=None):
""" """
Get current state of a user. Get current state of a user.
:param chat_id: :param chat_id:
:return: state of a user :return: state of a user
""" """
return await self.current_states.current_state(chat_id) if not chat_id:
chat_id = user_id
return await self.current_states.get_state(chat_id, user_id)
async def add_data(self, chat_id, **kwargs): async def add_data(self, user_id: int, chat_id: int=None, **kwargs):
""" """
Add data to states. Add data to states.
:param chat_id: :param chat_id:
""" """
if not chat_id:
chat_id = user_id
for key, value in kwargs.items(): for key, value in kwargs.items():
await self.current_states.add_data(chat_id, key, value) await self.current_states.set_data(chat_id, user_id, key, value)

View File

@ -159,11 +159,21 @@ class StateFilter(AdvancedCustomFilter):
key = 'state' key = 'state'
async def check(self, message, text): async def check(self, message, text):
result = await self.bot.current_states.current_state(message.from_user.id) if text == '*': return True
if result is False: return False if message.chat.type == 'group':
elif text == '*': return True group_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id)
elif type(text) is list: return result in text if group_state == text:
return result == text return True
elif group_state in text and type(text) is list:
return True
else:
user_state = await self.bot.current_states.get_state(message.chat.id,message.from_user.id)
if user_state == text:
return True
elif type(text) is list and user_state in text:
return True
class IsDigitFilter(SimpleCustomFilter): class IsDigitFilter(SimpleCustomFilter):
""" """

View File

@ -3,206 +3,6 @@ import pickle
class StateMemory:
def __init__(self):
self._states = {}
async def add_state(self, chat_id, state):
"""
Add a state.
:param chat_id:
:param state: new state
"""
if chat_id in self._states:
self._states[chat_id]['state'] = state
else:
self._states[chat_id] = {'state': state,'data': {}}
async def current_state(self, chat_id):
"""Current state"""
if chat_id in self._states: return self._states[chat_id]['state']
else: return False
async def delete_state(self, chat_id):
"""Delete a state"""
self._states.pop(chat_id)
def get_data(self, chat_id):
return self._states[chat_id]['data']
async def set(self, chat_id, new_state):
"""
Set a new state for a user.
:param chat_id:
:param new_state: new_state of a user
"""
await self.add_state(chat_id,new_state)
async def add_data(self, chat_id, key, value):
result = self._states[chat_id]['data'][key] = value
return result
async def finish(self, chat_id):
"""
Finish(delete) state of a user.
:param chat_id:
"""
await self.delete_state(chat_id)
def retrieve_data(self, chat_id):
"""
Save input text.
Usage:
with bot.retrieve_data(message.chat.id) as data:
data['name'] = message.text
Also, at the end of your 'Form' you can get the name:
data['name']
"""
return StateContext(self, chat_id)
class StateFile:
"""
Class to save states in a file.
"""
def __init__(self, filename):
self.file_path = filename
async def add_state(self, chat_id, state):
"""
Add a state.
:param chat_id:
:param state: new state
"""
states_data = self.read_data()
if chat_id in states_data:
states_data[chat_id]['state'] = state
return await self.save_data(states_data)
else:
states_data[chat_id] = {'state': state,'data': {}}
return await self.save_data(states_data)
async def current_state(self, chat_id):
"""Current state."""
states_data = self.read_data()
if chat_id in states_data: return states_data[chat_id]['state']
else: return False
async def delete_state(self, chat_id):
"""Delete a state"""
states_data = self.read_data()
states_data.pop(chat_id)
await self.save_data(states_data)
def read_data(self):
"""
Read the data from file.
"""
file = open(self.file_path, 'rb')
states_data = pickle.load(file)
file.close()
return states_data
def create_dir(self):
"""
Create directory .save-handlers.
"""
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
os.makedirs(dirs, exist_ok=True)
if not os.path.isfile(self.file_path):
with open(self.file_path,'wb') as file:
pickle.dump({}, file)
async def save_data(self, new_data):
"""
Save data after editing.
:param new_data:
"""
with open(self.file_path, 'wb+') as state_file:
pickle.dump(new_data, state_file, protocol=pickle.HIGHEST_PROTOCOL)
return True
def get_data(self, chat_id):
return self.read_data()[chat_id]['data']
async def set(self, chat_id, new_state):
"""
Set a new state for a user.
:param chat_id:
:param new_state: new_state of a user
"""
await self.add_state(chat_id,new_state)
async def add_data(self, chat_id, key, value):
states_data = self.read_data()
result = states_data[chat_id]['data'][key] = value
await self.save_data(result)
return result
async def finish(self, chat_id):
"""
Finish(delete) state of a user.
:param chat_id:
"""
await self.delete_state(chat_id)
def retrieve_data(self, chat_id):
"""
Save input text.
Usage:
with bot.retrieve_data(message.chat.id) as data:
data['name'] = message.text
Also, at the end of your 'Form' you can get the name:
data['name']
"""
return StateFileContext(self, chat_id)
class StateContext:
"""
Class for data.
"""
def __init__(self , obj: StateMemory, chat_id) -> None:
self.obj = obj
self.chat_id = chat_id
self.data = obj.get_data(chat_id)
async def __aenter__(self):
return self.data
async def __aexit__(self, exc_type, exc_val, exc_tb):
return
class StateFileContext:
"""
Class for data.
"""
def __init__(self , obj: StateFile, chat_id) -> None:
self.obj = obj
self.chat_id = chat_id
self.data = None
async def __aenter__(self):
self.data = self.obj.get_data(self.chat_id)
return self.data
async def __aexit__(self, exc_type, exc_val, exc_tb):
old_data = self.obj.read_data()
for i in self.data:
old_data[self.chat_id]['data'][i] = self.data.get(i)
await self.obj.save_data(old_data)
return
class BaseMiddleware: class BaseMiddleware:
""" """
Base class for middleware. Base class for middleware.
@ -217,3 +17,19 @@ class BaseMiddleware:
async def post_process(self, message, data, exception): async def post_process(self, message, data, exception):
raise NotImplementedError raise NotImplementedError
class State:
def __init__(self) -> None:
self.name = None
def __str__(self) -> str:
return self.name
class StatesGroup:
def __init_subclass__(cls) -> None:
# print all variables of a subclass
for name, value in cls.__dict__.items():
if not name.startswith('__') and not callable(value) and isinstance(value, State):
# change value of that variable
value.name = ':'.join((cls.__name__, name))

View File

@ -12,16 +12,8 @@ API_URL = 'https://api.telegram.org/bot{0}/{1}'
from datetime import datetime from datetime import datetime
import telebot import telebot
from telebot import util from telebot import util, logger
class SessionBase:
def __init__(self) -> None:
self.session = None
async def _get_new_session(self):
self.session = aiohttp.ClientSession()
return self.session
session_manager = SessionBase()
proxy = None proxy = None
session = None session = None
@ -36,6 +28,29 @@ REQUEST_TIMEOUT = 10
MAX_RETRIES = 3 MAX_RETRIES = 3
logger = telebot.logger logger = telebot.logger
REQUEST_LIMIT = 50
class SessionManager:
def __init__(self) -> None:
self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=REQUEST_LIMIT))
async def create_session(self):
self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=REQUEST_LIMIT))
return self.session
async def get_session(self):
if self.session.closed:
self.session = await self.create_session()
if not self.session._loop.is_running():
await self.session.close()
self.session = await self.create_session()
return self.session
session_manager = SessionManager()
async def _process_request(token, url, method='get', params=None, files=None, request_timeout=None): async def _process_request(token, url, method='get', params=None, files=None, request_timeout=None):
params = prepare_data(params, files) params = prepare_data(params, files)
if request_timeout is None: if request_timeout is None:
@ -43,19 +58,21 @@ async def _process_request(token, url, method='get', params=None, files=None, re
timeout = aiohttp.ClientTimeout(total=request_timeout) timeout = aiohttp.ClientTimeout(total=request_timeout)
got_result = False got_result = False
current_try=0 current_try=0
async with await session_manager._get_new_session() as session: session = await session_manager.get_session()
while not got_result and current_try<MAX_RETRIES-1: while not got_result and current_try<MAX_RETRIES-1:
current_try +=1 current_try +=1
try: try:
response = await session.request(method=method, url=API_URL.format(token, url), data=params, timeout=timeout) async with session.request(method=method, url=API_URL.format(token, url), data=params, timeout=timeout) as resp:
logger.debug("Request: method={0} url={1} params={2} files={3} request_timeout={4} current_try={5}".format(method, url, params, files, request_timeout, current_try).replace(token, token.split(':')[0] + ":{TOKEN}")) logger.debug("Request: method={0} url={1} params={2} files={3} request_timeout={4} current_try={5}".format(method, url, params, files, request_timeout, current_try).replace(token, token.split(':')[0] + ":{TOKEN}"))
json_result = await _check_result(url, response) json_result = await _check_result(url, resp)
if json_result: if json_result:
return json_result['result'] return json_result['result']
except (ApiTelegramException,ApiInvalidJSONException, ApiHTTPException) as e: except (ApiTelegramException,ApiInvalidJSONException, ApiHTTPException) as e:
raise e raise e
except: except aiohttp.ClientError as e:
pass logger.error('Aiohttp ClientError: {0}'.format(e.__class__.__name__))
except Exception as e:
logger.error(f'Unkown error: {e.__class__.__name__}')
if not got_result: if not got_result:
raise RequestTimeout("Request timeout. Request: method={0} url={1} params={2} files={3} request_timeout={4}".format(method, url, params, files, request_timeout, current_try)) raise RequestTimeout("Request timeout. Request: method={0} url={1} params={2} files={3} request_timeout={4}".format(method, url, params, files, request_timeout, current_try))
@ -143,8 +160,7 @@ async def download_file(token, file_path):
else: else:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
url = FILE_URL.format(token, file_path) url = FILE_URL.format(token, file_path)
# TODO: rewrite this method async with await session_manager.get_session() as session:
async with await session_manager._get_new_session() as session:
async with session.get(url, proxy=proxy) as response: async with session.get(url, proxy=proxy) as response:
result = await response.read() result = await response.read()
if response.status != 200: if response.status != 200:
@ -279,7 +295,7 @@ async def send_message(
return await _process_request(token, method_name, params=params) return await _process_request(token, method_name, params=params)
# here shit begins # methods
async def get_user_profile_photos(token, user_id, offset=None, limit=None): async def get_user_profile_photos(token, user_id, offset=None, limit=None):
method_url = r'getUserProfilePhotos' method_url = r'getUserProfilePhotos'

View File

@ -0,0 +1,13 @@
from telebot.asyncio_storage.memory_storage import StateMemoryStorage
from telebot.asyncio_storage.redis_storage import StateRedisStorage
from telebot.asyncio_storage.pickle_storage import StatePickleStorage
from telebot.asyncio_storage.base_storage import StateContext,StateStorageBase
__all__ = [
'StateStorageBase', 'StateContext',
'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage'
]

View File

@ -0,0 +1,69 @@
import copy
class StateStorageBase:
def __init__(self) -> None:
pass
async def set_data(self, chat_id, user_id, key, value):
"""
Set data for a user in a particular chat.
"""
raise NotImplementedError
async def get_data(self, chat_id, user_id):
"""
Get data for a user in a particular chat.
"""
raise NotImplementedError
async def set_state(self, chat_id, user_id, state):
"""
Set state for a particular user.
! Note that you should create a
record if it does not exist, and
if a record with state already exists,
you need to update a record.
"""
raise NotImplementedError
async def delete_state(self, chat_id, user_id):
"""
Delete state for a particular user.
"""
raise NotImplementedError
async def reset_data(self, chat_id, user_id):
"""
Reset data for a particular user in a chat.
"""
raise NotImplementedError
async def get_state(self, chat_id, user_id):
raise NotImplementedError
async def save(chat_id, user_id, data):
raise NotImplementedError
class StateContext:
"""
Class for data.
"""
def __init__(self, obj, chat_id, user_id):
self.obj = obj
self.data = None
self.chat_id = chat_id
self.user_id = user_id
async def __aenter__(self):
self.data = copy.deepcopy(await self.obj.get_data(self.chat_id, self.user_id))
return self.data
async def __aexit__(self, exc_type, exc_val, exc_tb):
return await self.obj.save(self.chat_id, self.user_id, self.data)

View File

@ -0,0 +1,64 @@
from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
class StateMemoryStorage(StateStorageBase):
def __init__(self) -> None:
self.data = {}
#
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
async def set_state(self, chat_id, user_id, state):
if chat_id in self.data:
if user_id in self.data[chat_id]:
self.data[chat_id][user_id]['state'] = state
return True
else:
self.data[chat_id][user_id] = {'state': state, 'data': {}}
return True
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
return True
async def delete_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
del self.data[chat_id][user_id]
if chat_id == user_id:
del self.data[chat_id]
return True
return False
async def get_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['state']
return None
async def get_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['data']
return None
async def reset_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'] = {}
return True
return False
async def set_data(self, chat_id, user_id, key, value):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'][key] = value
return True
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
def get_interactive_data(self, chat_id, user_id):
return StateContext(self, chat_id, user_id)
async def save(self, chat_id, user_id, data):
self.data[chat_id][user_id]['data'] = data

View File

@ -0,0 +1,107 @@
from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
import os
import pickle
class StatePickleStorage(StateStorageBase):
def __init__(self, file_path="./.state-save/states.pkl") -> None:
self.file_path = file_path
self.create_dir()
self.data = self.read()
async def convert_old_to_new(self):
# old looks like:
# {1: {'state': 'start', 'data': {'name': 'John'}}
# we should update old version pickle to new.
# new looks like:
# {1: {2: {'state': 'start', 'data': {'name': 'John'}}}}
new_data = {}
for key, value in self.data.items():
# this returns us id and dict with data and state
new_data[key] = {key: value} # convert this to new
# pass it to global data
self.data = new_data
self.update_data() # update data in file
def create_dir(self):
"""
Create directory .save-handlers.
"""
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
os.makedirs(dirs, exist_ok=True)
if not os.path.isfile(self.file_path):
with open(self.file_path,'wb') as file:
pickle.dump({}, file)
def read(self):
file = open(self.file_path, 'rb')
data = pickle.load(file)
file.close()
return data
def update_data(self):
file = open(self.file_path, 'wb+')
pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL)
file.close()
async def set_state(self, chat_id, user_id, state):
if chat_id in self.data:
if user_id in self.data[chat_id]:
self.data[chat_id][user_id]['state'] = state
return True
else:
self.data[chat_id][user_id] = {'state': state, 'data': {}}
return True
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
self.update_data()
return True
async def delete_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
del self.data[chat_id][user_id]
if chat_id == user_id:
del self.data[chat_id]
self.update_data()
return True
return False
async def get_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['state']
return None
async def get_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['data']
return None
async def reset_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'] = {}
self.update_data()
return True
return False
async def set_data(self, chat_id, user_id, key, value):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'][key] = value
self.update_data()
return True
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
def get_interactive_data(self, chat_id, user_id):
return StateContext(self, chat_id, user_id)
async def save(self, chat_id, user_id, data):
self.data[chat_id][user_id]['data'] = data
self.update_data()

View File

@ -0,0 +1,178 @@
from pickle import FALSE
from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
import json
redis_installed = True
try:
import aioredis
except:
redis_installed = False
class StateRedisStorage(StateStorageBase):
"""
This class is for Redis storage.
This will work only for states.
To use it, just pass this class to:
TeleBot(storage=StateRedisStorage())
"""
def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'):
if not redis_installed:
raise ImportError('AioRedis is not installed. Install it via "pip install aioredis"')
aioredis_version = tuple(map(int, aioredis.__version__.split(".")[0]))
if aioredis_version < (2,):
raise ImportError('Invalid aioredis version. Aioredis version should be >= 2.0.0')
self.redis = aioredis.Redis(host=host, port=port, db=db, password=password)
self.prefix = prefix
#self.con = Redis(connection_pool=self.redis) -> use this when necessary
#
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
async def get_record(self, key):
"""
Function to get record from database.
It has nothing to do with states.
Made for backend compatibility
"""
result = await self.redis.get(self.prefix+str(key))
if result: return json.loads(result)
return
async def set_record(self, key, value):
"""
Function to set record to database.
It has nothing to do with states.
Made for backend compatibility
"""
await self.redis.set(self.prefix+str(key), json.dumps(value))
return True
async def delete_record(self, key):
"""
Function to delete record from database.
It has nothing to do with states.
Made for backend compatibility
"""
await self.redis.delete(self.prefix+str(key))
return True
async def set_state(self, chat_id, user_id, state):
"""
Set state for a particular user in a chat.
"""
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['state'] = state
else:
response[user_id] = {'state': state, 'data': {}}
else:
response = {user_id: {'state': state, 'data': {}}}
await self.set_record(chat_id, response)
return True
async def delete_state(self, chat_id, user_id):
"""
Delete state for a particular user in a chat.
"""
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
del response[user_id]
if user_id == str(chat_id):
await self.delete_record(chat_id)
return True
else: await self.set_record(chat_id, response)
return True
return False
async def get_value(self, chat_id, user_id, key):
"""
Get value for a data of a user in a chat.
"""
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
if key in response[user_id]['data']:
return response[user_id]['data'][key]
return None
async def get_state(self, chat_id, user_id):
"""
Get state of a user in a chat.
"""
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
return response[user_id]['state']
return None
async def get_data(self, chat_id, user_id):
"""
Get data of particular user in a particular chat.
"""
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
return response[user_id]['data']
return None
async def reset_data(self, chat_id, user_id):
"""
Reset data of a user in a chat.
"""
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['data'] = {}
await self.set_record(chat_id, response)
return True
async def set_data(self, chat_id, user_id, key, value):
"""
Set data without interactive data.
"""
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['data'][key] = value
await self.set_record(chat_id, response)
return True
return False
def get_interactive_data(self, chat_id, user_id):
"""
Get Data in interactive way.
You can use with() with this function.
"""
return StateContext(self, chat_id, user_id)
async def save(self, chat_id, user_id, data):
response = await self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['data'] = dict(data, **response[user_id]['data'])
await self.set_record(chat_id, response)
return True

View File

@ -158,11 +158,21 @@ class StateFilter(AdvancedCustomFilter):
key = 'state' key = 'state'
def check(self, message, text): def check(self, message, text):
if self.bot.current_states.current_state(message.from_user.id) is False: return False if text == '*': return True
elif text == '*': return True if message.chat.type == 'group':
elif type(text) is list: return self.bot.current_states.current_state(message.from_user.id) in text group_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id)
return self.bot.current_states.current_state(message.from_user.id) == text if group_state == text:
return True
elif group_state in text and type(text) is list:
return True
else:
user_state = self.bot.current_states.get_state(message.chat.id,message.from_user.id)
if user_state == text:
return True
elif type(text) is list and user_state in text:
return True
class IsDigitFilter(SimpleCustomFilter): class IsDigitFilter(SimpleCustomFilter):
""" """
Filter to check whether the string is made up of only digits. Filter to check whether the string is made up of only digits.

View File

@ -3,6 +3,11 @@ import pickle
import threading import threading
from telebot import apihelper from telebot import apihelper
try:
from redis import Redis
redis_installed = True
except:
redis_installed = False
class HandlerBackend(object): class HandlerBackend(object):
@ -116,7 +121,8 @@ class FileHandlerBackend(HandlerBackend):
class RedisHandlerBackend(HandlerBackend): class RedisHandlerBackend(HandlerBackend):
def __init__(self, handlers=None, host='localhost', port=6379, db=0, prefix='telebot', password=None): def __init__(self, handlers=None, host='localhost', port=6379, db=0, prefix='telebot', password=None):
super(RedisHandlerBackend, self).__init__(handlers) super(RedisHandlerBackend, self).__init__(handlers)
from redis import Redis if not redis_installed:
raise Exception("Redis is not installed. Install it via 'pip install redis'")
self.prefix = prefix self.prefix = prefix
self.redis = Redis(host, port, db, password) self.redis = Redis(host, port, db, password)
@ -143,197 +149,19 @@ class RedisHandlerBackend(HandlerBackend):
return handlers return handlers
class StateMemory: class State:
def __init__(self): def __init__(self) -> None:
self._states = {} self.name = None
def __str__(self) -> str:
def add_state(self, chat_id, state): return self.name
"""
Add a state.
:param chat_id:
:param state: new state
"""
if chat_id in self._states:
self._states[chat_id]['state'] = state
else:
self._states[chat_id] = {'state': state,'data': {}}
def current_state(self, chat_id):
"""Current state"""
if chat_id in self._states: return self._states[chat_id]['state']
else: return False
def delete_state(self, chat_id):
"""Delete a state"""
self._states.pop(chat_id)
def get_data(self, chat_id):
return self._states[chat_id]['data']
def set(self, chat_id, new_state):
"""
Set a new state for a user.
:param chat_id:
:param new_state: new_state of a user
"""
self.add_state(chat_id,new_state)
def add_data(self, chat_id, key, value):
result = self._states[chat_id]['data'][key] = value
return result
def finish(self, chat_id):
"""
Finish(delete) state of a user.
:param chat_id:
"""
self.delete_state(chat_id)
def retrieve_data(self, chat_id):
"""
Save input text.
Usage:
with bot.retrieve_data(message.chat.id) as data:
data['name'] = message.text
Also, at the end of your 'Form' you can get the name:
data['name']
"""
return StateContext(self, chat_id)
class StateFile: class StatesGroup:
""" def __init_subclass__(cls) -> None:
Class to save states in a file. # print all variables of a subclass
""" for name, value in cls.__dict__.items():
def __init__(self, filename): if not name.startswith('__') and not callable(value) and isinstance(value, State):
self.file_path = filename # change value of that variable
value.name = ':'.join((cls.__name__, name))
def add_state(self, chat_id, state):
"""
Add a state.
:param chat_id:
:param state: new state
"""
states_data = self.read_data()
if chat_id in states_data:
states_data[chat_id]['state'] = state
return self.save_data(states_data)
else:
states_data[chat_id] = {'state': state,'data': {}}
return self.save_data(states_data)
def current_state(self, chat_id):
"""Current state."""
states_data = self.read_data()
if chat_id in states_data: return states_data[chat_id]['state']
else: return False
def delete_state(self, chat_id):
"""Delete a state"""
states_data = self.read_data()
states_data.pop(chat_id)
self.save_data(states_data)
def read_data(self):
"""
Read the data from file.
"""
file = open(self.file_path, 'rb')
states_data = pickle.load(file)
file.close()
return states_data
def create_dir(self):
"""
Create directory .save-handlers.
"""
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
os.makedirs(dirs, exist_ok=True)
if not os.path.isfile(self.file_path):
with open(self.file_path,'wb') as file:
pickle.dump({}, file)
def save_data(self, new_data):
"""
Save data after editing.
:param new_data:
"""
with open(self.file_path, 'wb+') as state_file:
pickle.dump(new_data, state_file, protocol=pickle.HIGHEST_PROTOCOL)
return True
def get_data(self, chat_id):
return self.read_data()[chat_id]['data']
def set(self, chat_id, new_state):
"""
Set a new state for a user.
:param chat_id:
:param new_state: new_state of a user
"""
self.add_state(chat_id,new_state)
def add_data(self, chat_id, key, value):
states_data = self.read_data()
result = states_data[chat_id]['data'][key] = value
self.save_data(result)
return result
def finish(self, chat_id):
"""
Finish(delete) state of a user.
:param chat_id:
"""
self.delete_state(chat_id)
def retrieve_data(self, chat_id):
"""
Save input text.
Usage:
with bot.retrieve_data(message.chat.id) as data:
data['name'] = message.text
Also, at the end of your 'Form' you can get the name:
data['name']
"""
return StateFileContext(self, chat_id)
class StateContext:
"""
Class for data.
"""
def __init__(self , obj: StateMemory, chat_id) -> None:
self.obj = obj
self.chat_id = chat_id
self.data = obj.get_data(chat_id)
def __enter__(self):
return self.data
def __exit__(self, exc_type, exc_val, exc_tb):
return
class StateFileContext:
"""
Class for data.
"""
def __init__(self , obj: StateFile, chat_id) -> None:
self.obj = obj
self.chat_id = chat_id
self.data = self.obj.get_data(self.chat_id)
def __enter__(self):
return self.data
def __exit__(self, exc_type, exc_val, exc_tb):
old_data = self.obj.read_data()
for i in self.data:
old_data[self.chat_id]['data'][i] = self.data.get(i)
self.obj.save_data(old_data)
return

View File

@ -0,0 +1,13 @@
from telebot.storage.memory_storage import StateMemoryStorage
from telebot.storage.redis_storage import StateRedisStorage
from telebot.storage.pickle_storage import StatePickleStorage
from telebot.storage.base_storage import StateContext,StateStorageBase
__all__ = [
'StateStorageBase', 'StateContext',
'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage'
]

View File

@ -0,0 +1,65 @@
import copy
class StateStorageBase:
def __init__(self) -> None:
pass
def set_data(self, chat_id, user_id, key, value):
"""
Set data for a user in a particular chat.
"""
raise NotImplementedError
def get_data(self, chat_id, user_id):
"""
Get data for a user in a particular chat.
"""
raise NotImplementedError
def set_state(self, chat_id, user_id, state):
"""
Set state for a particular user.
! Note that you should create a
record if it does not exist, and
if a record with state already exists,
you need to update a record.
"""
raise NotImplementedError
def delete_state(self, chat_id, user_id):
"""
Delete state for a particular user.
"""
raise NotImplementedError
def reset_data(self, chat_id, user_id):
"""
Reset data for a particular user in a chat.
"""
raise NotImplementedError
def get_state(self, chat_id, user_id):
raise NotImplementedError
def save(chat_id, user_id, data):
raise NotImplementedError
class StateContext:
"""
Class for data.
"""
def __init__(self , obj, chat_id, user_id) -> None:
self.obj = obj
self.data = copy.deepcopy(obj.get_data(chat_id, user_id))
self.chat_id = chat_id
self.user_id = user_id
def __enter__(self):
return self.data
def __exit__(self, exc_type, exc_val, exc_tb):
return self.obj.save(self.chat_id, self.user_id, self.data)

View File

@ -0,0 +1,64 @@
from telebot.storage.base_storage import StateStorageBase, StateContext
class StateMemoryStorage(StateStorageBase):
def __init__(self) -> None:
self.data = {}
#
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
def set_state(self, chat_id, user_id, state):
if chat_id in self.data:
if user_id in self.data[chat_id]:
self.data[chat_id][user_id]['state'] = state
return True
else:
self.data[chat_id][user_id] = {'state': state, 'data': {}}
return True
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
return True
def delete_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
del self.data[chat_id][user_id]
if chat_id == user_id:
del self.data[chat_id]
return True
return False
def get_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['state']
return None
def get_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['data']
return None
def reset_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'] = {}
return True
return False
def set_data(self, chat_id, user_id, key, value):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'][key] = value
return True
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
def get_interactive_data(self, chat_id, user_id):
return StateContext(self, chat_id, user_id)
def save(self, chat_id, user_id, data):
self.data[chat_id][user_id]['data'] = data

View File

@ -0,0 +1,112 @@
from telebot.storage.base_storage import StateStorageBase, StateContext
import os
import pickle
class StatePickleStorage(StateStorageBase):
def __init__(self, file_path="./.state-save/states.pkl") -> None:
self.file_path = file_path
self.create_dir()
self.data = self.read()
def convert_old_to_new(self):
"""
Use this function to convert old storage to new storage.
This function is for people who was using pickle storage
that was in version <=4.3.1.
"""
# old looks like:
# {1: {'state': 'start', 'data': {'name': 'John'}}
# we should update old version pickle to new.
# new looks like:
# {1: {2: {'state': 'start', 'data': {'name': 'John'}}}}
new_data = {}
for key, value in self.data.items():
# this returns us id and dict with data and state
new_data[key] = {key: value} # convert this to new
# pass it to global data
self.data = new_data
self.update_data() # update data in file
def create_dir(self):
"""
Create directory .save-handlers.
"""
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
os.makedirs(dirs, exist_ok=True)
if not os.path.isfile(self.file_path):
with open(self.file_path,'wb') as file:
pickle.dump({}, file)
def read(self):
file = open(self.file_path, 'rb')
data = pickle.load(file)
file.close()
return data
def update_data(self):
file = open(self.file_path, 'wb+')
pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL)
file.close()
def set_state(self, chat_id, user_id, state):
if chat_id in self.data:
if user_id in self.data[chat_id]:
self.data[chat_id][user_id]['state'] = state
return True
else:
self.data[chat_id][user_id] = {'state': state, 'data': {}}
return True
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
self.update_data()
return True
def delete_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
del self.data[chat_id][user_id]
if chat_id == user_id:
del self.data[chat_id]
self.update_data()
return True
return False
def get_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['state']
return None
def get_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['data']
return None
def reset_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'] = {}
self.update_data()
return True
return False
def set_data(self, chat_id, user_id, key, value):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'][key] = value
self.update_data()
return True
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
def get_interactive_data(self, chat_id, user_id):
return StateContext(self, chat_id, user_id)
def save(self, chat_id, user_id, data):
self.data[chat_id][user_id]['data'] = data
self.update_data()

View File

@ -0,0 +1,176 @@
from telebot.storage.base_storage import StateStorageBase, StateContext
import json
redis_installed = True
try:
from redis import Redis, ConnectionPool
except:
redis_installed = False
class StateRedisStorage(StateStorageBase):
"""
This class is for Redis storage.
This will work only for states.
To use it, just pass this class to:
TeleBot(storage=StateRedisStorage())
"""
def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'):
self.redis = ConnectionPool(host=host, port=port, db=db, password=password)
#self.con = Redis(connection_pool=self.redis) -> use this when necessary
#
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
self.prefix = prefix
if not redis_installed:
raise Exception("Redis is not installed. Install it via 'pip install redis'")
def get_record(self, key):
"""
Function to get record from database.
It has nothing to do with states.
Made for backend compatibility
"""
connection = Redis(connection_pool=self.redis)
result = connection.get(self.prefix+str(key))
connection.close()
if result: return json.loads(result)
return
def set_record(self, key, value):
"""
Function to set record to database.
It has nothing to do with states.
Made for backend compatibility
"""
connection = Redis(connection_pool=self.redis)
connection.set(self.prefix+str(key), json.dumps(value))
connection.close()
return True
def delete_record(self, key):
"""
Function to delete record from database.
It has nothing to do with states.
Made for backend compatibility
"""
connection = Redis(connection_pool=self.redis)
connection.delete(self.prefix+str(key))
connection.close()
return True
def set_state(self, chat_id, user_id, state):
"""
Set state for a particular user in a chat.
"""
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['state'] = state
else:
response[user_id] = {'state': state, 'data': {}}
else:
response = {user_id: {'state': state, 'data': {}}}
self.set_record(chat_id, response)
return True
def delete_state(self, chat_id, user_id):
"""
Delete state for a particular user in a chat.
"""
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
del response[user_id]
if user_id == str(chat_id):
self.delete_record(chat_id)
return True
else: self.set_record(chat_id, response)
return True
return False
def get_value(self, chat_id, user_id, key):
"""
Get value for a data of a user in a chat.
"""
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
if key in response[user_id]['data']:
return response[user_id]['data'][key]
return None
def get_state(self, chat_id, user_id):
"""
Get state of a user in a chat.
"""
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
return response[user_id]['state']
return None
def get_data(self, chat_id, user_id):
"""
Get data of particular user in a particular chat.
"""
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
return response[user_id]['data']
return None
def reset_data(self, chat_id, user_id):
"""
Reset data of a user in a chat.
"""
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['data'] = {}
self.set_record(chat_id, response)
return True
def set_data(self, chat_id, user_id, key, value):
"""
Set data without interactive data.
"""
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['data'][key] = value
self.set_record(chat_id, response)
return True
return False
def get_interactive_data(self, chat_id, user_id):
"""
Get Data in interactive way.
You can use with() with this function.
"""
return StateContext(self, chat_id, user_id)
def save(self, chat_id, user_id, data):
response = self.get_record(chat_id)
user_id = str(user_id)
if response:
if user_id in response:
response[user_id]['data'] = dict(data, **response[user_id]['data'])
self.set_record(chat_id, response)
return True

0
tests/__init__.py Normal file
View File