mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
Merge pull request #1421 from coder2020official/master
RedisStorage, middleware fix, pass_bot parameter and more
This commit is contained in:
commit
2e9947277a
@ -1,15 +1,28 @@
|
||||
import telebot
|
||||
from telebot import asyncio_filters
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
bot = AsyncTeleBot('TOKEN')
|
||||
|
||||
# list of storages, you can use any storage
|
||||
from telebot.asyncio_storage import StateRedisStorage,StateMemoryStorage,StatePickleStorage
|
||||
|
||||
# new feature for states.
|
||||
from telebot.asyncio_handler_backends import State, StatesGroup
|
||||
|
||||
# default state storage is statememorystorage
|
||||
bot = AsyncTeleBot('TOKEN', state_storage=StateMemoryStorage())
|
||||
|
||||
|
||||
# Just create different statesgroup
|
||||
class MyStates(StatesGroup):
|
||||
name = State() # statesgroup should contain states
|
||||
surname = State()
|
||||
age = State()
|
||||
|
||||
|
||||
|
||||
class MyStates:
|
||||
name = 1
|
||||
surname = 2
|
||||
age = 3
|
||||
|
||||
# set_state -> sets a new state
|
||||
# delete_state -> delets state if exists
|
||||
# get_state -> returns state if exists
|
||||
|
||||
|
||||
@bot.message_handler(commands=['start'])
|
||||
@ -17,7 +30,7 @@ async def start_ex(message):
|
||||
"""
|
||||
Start command. Here we are starting state
|
||||
"""
|
||||
await bot.set_state(message.from_user.id, MyStates.name)
|
||||
await bot.set_state(message.from_user.id, MyStates.name, message.chat.id)
|
||||
await bot.send_message(message.chat.id, 'Hi, write me a name')
|
||||
|
||||
|
||||
@ -28,39 +41,45 @@ async def any_state(message):
|
||||
Cancel state
|
||||
"""
|
||||
await bot.send_message(message.chat.id, "Your state was cancelled.")
|
||||
await bot.delete_state(message.from_user.id)
|
||||
await bot.delete_state(message.from_user.id, message.chat.id)
|
||||
|
||||
@bot.message_handler(state=MyStates.name)
|
||||
async def name_get(message):
|
||||
"""
|
||||
State 1. Will process when user's state is 1.
|
||||
State 1. Will process when user's state is MyStates.name.
|
||||
"""
|
||||
await bot.send_message(message.chat.id, f'Now write me a surname')
|
||||
await bot.set_state(message.from_user.id, MyStates.surname)
|
||||
async with bot.retrieve_data(message.from_user.id) as data:
|
||||
await bot.set_state(message.from_user.id, MyStates.surname, message.chat.id)
|
||||
async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
|
||||
data['name'] = message.text
|
||||
|
||||
|
||||
@bot.message_handler(state=MyStates.surname)
|
||||
async def ask_age(message):
|
||||
"""
|
||||
State 2. Will process when user's state is 2.
|
||||
State 2. Will process when user's state is MyStates.surname.
|
||||
"""
|
||||
await bot.send_message(message.chat.id, "What is your age?")
|
||||
await bot.set_state(message.from_user.id, MyStates.age)
|
||||
async with bot.retrieve_data(message.from_user.id) as data:
|
||||
await bot.set_state(message.from_user.id, MyStates.age, message.chat.id)
|
||||
async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
|
||||
data['surname'] = message.text
|
||||
|
||||
# result
|
||||
@bot.message_handler(state=MyStates.age, is_digit=True)
|
||||
async def ready_for_answer(message):
|
||||
async with bot.retrieve_data(message.from_user.id) as data:
|
||||
"""
|
||||
State 3. Will process when user's state is MyStates.age.
|
||||
"""
|
||||
async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
|
||||
await bot.send_message(message.chat.id, "Ready, take a look:\n<b>Name: {name}\nSurname: {surname}\nAge: {age}</b>".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html")
|
||||
await bot.delete_state(message.from_user.id)
|
||||
await bot.delete_state(message.from_user.id, message.chat.id)
|
||||
|
||||
#incorrect number
|
||||
@bot.message_handler(state=MyStates.age, is_digit=False)
|
||||
async def age_incorrect(message):
|
||||
"""
|
||||
Will process for wrong input when state is MyState.age
|
||||
"""
|
||||
await bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number')
|
||||
|
||||
# register filters
|
||||
@ -68,8 +87,6 @@ async def age_incorrect(message):
|
||||
bot.add_custom_filter(asyncio_filters.StateFilter(bot))
|
||||
bot.add_custom_filter(asyncio_filters.IsDigitFilter())
|
||||
|
||||
# set saving states into file.
|
||||
bot.enable_saving_states() # you can delete this if you do not need to save states
|
||||
|
||||
import asyncio
|
||||
asyncio.run(bot.polling())
|
@ -1,14 +1,39 @@
|
||||
import telebot
|
||||
import telebot # telebot
|
||||
|
||||
from telebot import custom_filters
|
||||
from telebot.handler_backends import State, StatesGroup #States
|
||||
|
||||
bot = telebot.TeleBot("")
|
||||
# States storage
|
||||
from telebot.storage import StateRedisStorage, StatePickleStorage, StateMemoryStorage
|
||||
|
||||
|
||||
class MyStates:
|
||||
name = 1
|
||||
surname = 2
|
||||
age = 3
|
||||
# Beginning from version 4.4.0+, we support storages.
|
||||
# StateRedisStorage -> Redis-based storage.
|
||||
# StatePickleStorage -> Pickle-based storage.
|
||||
# For redis, you will need to install redis.
|
||||
# Pass host, db, password, or anything else,
|
||||
# if you need to change config for redis.
|
||||
# Pickle requires path. Default path is in folder .state-saves.
|
||||
# If you were using older version of pytba for pickle,
|
||||
# you need to migrate from old pickle to new by using
|
||||
# StatePickleStorage().convert_old_to_new()
|
||||
|
||||
|
||||
|
||||
# Now, you can pass storage to bot.
|
||||
state_storage = StateMemoryStorage() # you can init here another storage
|
||||
|
||||
bot = telebot.TeleBot("TOKEN",
|
||||
state_storage=state_storage)
|
||||
|
||||
|
||||
# States group.
|
||||
class MyStates(StatesGroup):
|
||||
# Just name variables differently
|
||||
name = State() # creating instances of State class is enough from now
|
||||
surname = State()
|
||||
age = State()
|
||||
|
||||
|
||||
|
||||
|
||||
@ -17,50 +42,56 @@ def start_ex(message):
|
||||
"""
|
||||
Start command. Here we are starting state
|
||||
"""
|
||||
bot.set_state(message.from_user.id, MyStates.name)
|
||||
bot.set_state(message.from_user.id, MyStates.name, message.chat.id)
|
||||
bot.send_message(message.chat.id, 'Hi, write me a name')
|
||||
|
||||
|
||||
|
||||
# Any state
|
||||
@bot.message_handler(state="*", commands='cancel')
|
||||
def any_state(message):
|
||||
"""
|
||||
Cancel state
|
||||
"""
|
||||
bot.send_message(message.chat.id, "Your state was cancelled.")
|
||||
bot.delete_state(message.from_user.id)
|
||||
bot.delete_state(message.from_user.id, message.chat.id)
|
||||
|
||||
@bot.message_handler(state=MyStates.name)
|
||||
def name_get(message):
|
||||
"""
|
||||
State 1. Will process when user's state is 1.
|
||||
State 1. Will process when user's state is MyStates.name.
|
||||
"""
|
||||
bot.send_message(message.chat.id, f'Now write me a surname')
|
||||
bot.set_state(message.from_user.id, MyStates.surname)
|
||||
with bot.retrieve_data(message.from_user.id) as data:
|
||||
bot.set_state(message.from_user.id, MyStates.surname, message.chat.id)
|
||||
with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
|
||||
data['name'] = message.text
|
||||
|
||||
|
||||
@bot.message_handler(state=MyStates.surname)
|
||||
def ask_age(message):
|
||||
"""
|
||||
State 2. Will process when user's state is 2.
|
||||
State 2. Will process when user's state is MyStates.surname.
|
||||
"""
|
||||
bot.send_message(message.chat.id, "What is your age?")
|
||||
bot.set_state(message.from_user.id, MyStates.age)
|
||||
with bot.retrieve_data(message.from_user.id) as data:
|
||||
bot.set_state(message.from_user.id, MyStates.age, message.chat.id)
|
||||
with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
|
||||
data['surname'] = message.text
|
||||
|
||||
# result
|
||||
@bot.message_handler(state=MyStates.age, is_digit=True)
|
||||
def ready_for_answer(message):
|
||||
with bot.retrieve_data(message.from_user.id) as data:
|
||||
"""
|
||||
State 3. Will process when user's state is MyStates.age.
|
||||
"""
|
||||
with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
|
||||
bot.send_message(message.chat.id, "Ready, take a look:\n<b>Name: {name}\nSurname: {surname}\nAge: {age}</b>".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html")
|
||||
bot.delete_state(message.from_user.id)
|
||||
bot.delete_state(message.from_user.id, message.chat.id)
|
||||
|
||||
#incorrect number
|
||||
@bot.message_handler(state=MyStates.age, is_digit=False)
|
||||
def age_incorrect(message):
|
||||
"""
|
||||
Wrong response for MyStates.age
|
||||
"""
|
||||
bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number')
|
||||
|
||||
# register filters
|
||||
@ -68,7 +99,4 @@ def age_incorrect(message):
|
||||
bot.add_custom_filter(custom_filters.StateFilter(bot))
|
||||
bot.add_custom_filter(custom_filters.IsDigitFilter())
|
||||
|
||||
# set saving states into file.
|
||||
bot.enable_saving_states() # you can delete this if you do not need to save states
|
||||
|
||||
bot.infinity_polling(skip_pending=True)
|
@ -13,7 +13,8 @@ from typing import Any, Callable, List, Optional, Union
|
||||
import telebot.util
|
||||
import telebot.types
|
||||
|
||||
|
||||
# storage
|
||||
from telebot.storage import StatePickleStorage, StateMemoryStorage
|
||||
|
||||
logger = logging.getLogger('TeleBot')
|
||||
|
||||
@ -28,7 +29,7 @@ logger.addHandler(console_output_handler)
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
from telebot import apihelper, util, types
|
||||
from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend, StateMemory, StateFile
|
||||
from telebot.handler_backends import MemoryHandlerBackend, FileHandlerBackend
|
||||
from telebot.custom_filters import SimpleCustomFilter, AdvancedCustomFilter
|
||||
|
||||
|
||||
@ -148,7 +149,7 @@ class TeleBot:
|
||||
def __init__(
|
||||
self, token, parse_mode=None, threaded=True, skip_pending=False, num_threads=2,
|
||||
next_step_backend=None, reply_backend=None, exception_handler=None, last_update_id=0,
|
||||
suppress_middleware_excepions=False
|
||||
suppress_middleware_excepions=False, state_storage=StateMemoryStorage()
|
||||
):
|
||||
"""
|
||||
:param token: bot API token
|
||||
@ -193,7 +194,7 @@ class TeleBot:
|
||||
self.custom_filters = {}
|
||||
self.state_handlers = []
|
||||
|
||||
self.current_states = StateMemory()
|
||||
self.current_states = state_storage
|
||||
|
||||
|
||||
if apihelper.ENABLE_MIDDLEWARE:
|
||||
@ -251,7 +252,7 @@ class TeleBot:
|
||||
:param filename: Filename of saving file
|
||||
"""
|
||||
|
||||
self.current_states = StateFile(filename=filename)
|
||||
self.current_states = StatePickleStorage(filename=filename)
|
||||
self.current_states.create_dir()
|
||||
|
||||
def enable_save_reply_handlers(self, delay=120, filename="./.handler-saves/reply.save"):
|
||||
@ -777,6 +778,13 @@ class TeleBot:
|
||||
logger.info('Stopped polling.')
|
||||
|
||||
def _exec_task(self, task, *args, **kwargs):
|
||||
if kwargs.get('task_type') == 'handler':
|
||||
pass_bot = kwargs.get('pass_bot')
|
||||
kwargs.pop('pass_bot')
|
||||
kwargs.pop('task_type')
|
||||
if pass_bot:
|
||||
kwargs['bot'] = self
|
||||
|
||||
if self.threaded:
|
||||
self.worker_pool.put(task, *args, **kwargs)
|
||||
else:
|
||||
@ -2531,40 +2539,59 @@ class TeleBot:
|
||||
chat_id = message.chat.id
|
||||
self.register_next_step_handler_by_chat_id(chat_id, callback, *args, **kwargs)
|
||||
|
||||
def set_state(self, chat_id: int, state: Union[int, str]):
|
||||
def set_state(self, user_id: int, state: Union[int, str], chat_id: int=None) -> None:
|
||||
"""
|
||||
Sets a new state of a user.
|
||||
:param chat_id:
|
||||
:param state: new state. can be string or integer.
|
||||
"""
|
||||
self.current_states.add_state(chat_id, state)
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
self.current_states.set_state(chat_id, user_id, state)
|
||||
|
||||
def delete_state(self, chat_id: int):
|
||||
def reset_data(self, user_id: int, chat_id: int=None):
|
||||
"""
|
||||
Reset data for a user in chat.
|
||||
:param user_id:
|
||||
:param chat_id:
|
||||
"""
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
self.current_states.reset_data(chat_id, user_id)
|
||||
def delete_state(self, user_id: int, chat_id: int=None) -> None:
|
||||
"""
|
||||
Delete the current state of a user.
|
||||
:param chat_id:
|
||||
:return:
|
||||
"""
|
||||
self.current_states.delete_state(chat_id)
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
self.current_states.delete_state(chat_id, user_id)
|
||||
|
||||
def retrieve_data(self, chat_id: int):
|
||||
return self.current_states.retrieve_data(chat_id)
|
||||
def retrieve_data(self, user_id: int, chat_id: int=None) -> Optional[Union[int, str]]:
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
return self.current_states.get_interactive_data(chat_id, user_id)
|
||||
|
||||
def get_state(self, chat_id: int):
|
||||
def get_state(self, user_id: int, chat_id: int=None) -> Optional[Union[int, str]]:
|
||||
"""
|
||||
Get current state of a user.
|
||||
:param chat_id:
|
||||
:return: state of a user
|
||||
"""
|
||||
return self.current_states.current_state(chat_id)
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
return self.current_states.get_state(chat_id, user_id)
|
||||
|
||||
def add_data(self, chat_id: int, **kwargs):
|
||||
def add_data(self, user_id: int, chat_id:int=None, **kwargs):
|
||||
"""
|
||||
Add data to states.
|
||||
:param chat_id:
|
||||
"""
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
for key, value in kwargs.items():
|
||||
self.current_states.add_data(chat_id, key, value)
|
||||
self.current_states.set_data(chat_id, user_id, key, value)
|
||||
|
||||
def register_next_step_handler_by_chat_id(
|
||||
self, chat_id: Union[int, str], callback: Callable, *args, **kwargs) -> None:
|
||||
@ -2632,7 +2659,7 @@ class TeleBot:
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _build_handler_dict(handler, **filters):
|
||||
def _build_handler_dict(handler, pass_bot=False, **filters):
|
||||
"""
|
||||
Builds a dictionary for a handler
|
||||
:param handler:
|
||||
@ -2641,6 +2668,7 @@ class TeleBot:
|
||||
"""
|
||||
return {
|
||||
'function': handler,
|
||||
'pass_bot': pass_bot,
|
||||
'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None}
|
||||
# Remove None values, they are skipped in _test_filter anyway
|
||||
#'filters': filters
|
||||
@ -2686,7 +2714,7 @@ class TeleBot:
|
||||
:return:
|
||||
"""
|
||||
if not apihelper.ENABLE_MIDDLEWARE:
|
||||
raise RuntimeError("Middleware is not enabled. Use apihelper.ENABLE_MIDDLEWARE.")
|
||||
raise RuntimeError("Middleware is not enabled. Use apihelper.ENABLE_MIDDLEWARE before initialising TeleBot.")
|
||||
|
||||
if update_types:
|
||||
for update_type in update_types:
|
||||
@ -2694,6 +2722,27 @@ class TeleBot:
|
||||
else:
|
||||
self.default_middleware_handlers.append(handler)
|
||||
|
||||
# function register_middleware_handler
|
||||
def register_middleware_handler(self, callback, update_types=None):
|
||||
"""
|
||||
Middleware handler decorator.
|
||||
|
||||
This function will create a decorator that can be used to decorate functions that must be handled as middlewares before entering any other
|
||||
message handlers
|
||||
But, be careful and check type of the update inside the handler if more than one update_type is given
|
||||
|
||||
Example:
|
||||
|
||||
bot = TeleBot('TOKEN')
|
||||
|
||||
bot.register_middleware_handler(print_channel_post_text, update_types=['channel_post', 'edited_channel_post'])
|
||||
|
||||
:param update_types: Optional list of update types that can be passed into the middleware handler.
|
||||
|
||||
"""
|
||||
|
||||
self.add_middleware_handler(callback, update_types)
|
||||
|
||||
def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs):
|
||||
"""
|
||||
Message handler decorator.
|
||||
@ -2766,7 +2815,7 @@ class TeleBot:
|
||||
"""
|
||||
self.message_handlers.append(handler_dict)
|
||||
|
||||
def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs):
|
||||
def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers message handler.
|
||||
:param callback: function to be called
|
||||
@ -2775,6 +2824,7 @@ class TeleBot:
|
||||
:param regexp:
|
||||
:param func:
|
||||
:param chat_types: True for private chat
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
if isinstance(commands, str):
|
||||
@ -2791,6 +2841,7 @@ class TeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_message_handler(handler_dict)
|
||||
|
||||
@ -2838,7 +2889,7 @@ class TeleBot:
|
||||
"""
|
||||
self.edited_message_handlers.append(handler_dict)
|
||||
|
||||
def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs):
|
||||
def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers edited message handler.
|
||||
:param callback: function to be called
|
||||
@ -2847,6 +2898,7 @@ class TeleBot:
|
||||
:param regexp:
|
||||
:param func:
|
||||
:param chat_types: True for private chat
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
if isinstance(commands, str):
|
||||
@ -2863,6 +2915,7 @@ class TeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_edited_message_handler(handler_dict)
|
||||
|
||||
@ -2908,7 +2961,7 @@ class TeleBot:
|
||||
"""
|
||||
self.channel_post_handlers.append(handler_dict)
|
||||
|
||||
def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs):
|
||||
def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers channel post message handler.
|
||||
:param callback: function to be called
|
||||
@ -2916,6 +2969,7 @@ class TeleBot:
|
||||
:param commands: list of commands
|
||||
:param regexp:
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
if isinstance(commands, str):
|
||||
@ -2931,6 +2985,7 @@ class TeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_channel_post_handler(handler_dict)
|
||||
|
||||
@ -2975,7 +3030,7 @@ class TeleBot:
|
||||
"""
|
||||
self.edited_channel_post_handlers.append(handler_dict)
|
||||
|
||||
def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs):
|
||||
def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers edited channel post message handler.
|
||||
:param callback: function to be called
|
||||
@ -2983,6 +3038,7 @@ class TeleBot:
|
||||
:param commands: list of commands
|
||||
:param regexp:
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
if isinstance(commands, str):
|
||||
@ -2998,6 +3054,7 @@ class TeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_edited_channel_post_handler(handler_dict)
|
||||
|
||||
@ -3024,14 +3081,15 @@ class TeleBot:
|
||||
"""
|
||||
self.inline_handlers.append(handler_dict)
|
||||
|
||||
def register_inline_handler(self, callback, func, **kwargs):
|
||||
def register_inline_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers inline handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_inline_handler(handler_dict)
|
||||
|
||||
def chosen_inline_handler(self, func, **kwargs):
|
||||
@ -3057,14 +3115,15 @@ class TeleBot:
|
||||
"""
|
||||
self.chosen_inline_handlers.append(handler_dict)
|
||||
|
||||
def register_chosen_inline_handler(self, callback, func, **kwargs):
|
||||
def register_chosen_inline_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers chosen inline handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_chosen_inline_handler(handler_dict)
|
||||
|
||||
def callback_query_handler(self, func, **kwargs):
|
||||
@ -3090,14 +3149,15 @@ class TeleBot:
|
||||
"""
|
||||
self.callback_query_handlers.append(handler_dict)
|
||||
|
||||
def register_callback_query_handler(self, callback, func, **kwargs):
|
||||
def register_callback_query_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers callback query handler..
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_callback_query_handler(handler_dict)
|
||||
|
||||
def shipping_query_handler(self, func, **kwargs):
|
||||
@ -3123,14 +3183,15 @@ class TeleBot:
|
||||
"""
|
||||
self.shipping_query_handlers.append(handler_dict)
|
||||
|
||||
def register_shipping_query_handler(self, callback, func, **kwargs):
|
||||
def register_shipping_query_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers shipping query handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_shipping_query_handler(handler_dict)
|
||||
|
||||
def pre_checkout_query_handler(self, func, **kwargs):
|
||||
@ -3156,14 +3217,15 @@ class TeleBot:
|
||||
"""
|
||||
self.pre_checkout_query_handlers.append(handler_dict)
|
||||
|
||||
def register_pre_checkout_query_handler(self, callback, func, **kwargs):
|
||||
def register_pre_checkout_query_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers pre-checkout request handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_pre_checkout_query_handler(handler_dict)
|
||||
|
||||
def poll_handler(self, func, **kwargs):
|
||||
@ -3189,14 +3251,15 @@ class TeleBot:
|
||||
"""
|
||||
self.poll_handlers.append(handler_dict)
|
||||
|
||||
def register_poll_handler(self, callback, func, **kwargs):
|
||||
def register_poll_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers poll handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_poll_handler(handler_dict)
|
||||
|
||||
def poll_answer_handler(self, func=None, **kwargs):
|
||||
@ -3222,14 +3285,15 @@ class TeleBot:
|
||||
"""
|
||||
self.poll_answer_handlers.append(handler_dict)
|
||||
|
||||
def register_poll_answer_handler(self, callback, func, **kwargs):
|
||||
def register_poll_answer_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers poll answer handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_poll_answer_handler(handler_dict)
|
||||
|
||||
def my_chat_member_handler(self, func=None, **kwargs):
|
||||
@ -3255,14 +3319,15 @@ class TeleBot:
|
||||
"""
|
||||
self.my_chat_member_handlers.append(handler_dict)
|
||||
|
||||
def register_my_chat_member_handler(self, callback, func=None, **kwargs):
|
||||
def register_my_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers my chat member handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_my_chat_member_handler(handler_dict)
|
||||
|
||||
def chat_member_handler(self, func=None, **kwargs):
|
||||
@ -3288,14 +3353,15 @@ class TeleBot:
|
||||
"""
|
||||
self.chat_member_handlers.append(handler_dict)
|
||||
|
||||
def register_chat_member_handler(self, callback, func=None, **kwargs):
|
||||
def register_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers chat member handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_chat_member_handler(handler_dict)
|
||||
|
||||
def chat_join_request_handler(self, func=None, **kwargs):
|
||||
@ -3321,14 +3387,15 @@ class TeleBot:
|
||||
"""
|
||||
self.chat_join_request_handlers.append(handler_dict)
|
||||
|
||||
def register_chat_join_request_handler(self, callback, func=None, **kwargs):
|
||||
def register_chat_join_request_handler(self, callback, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers chat join request handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:param pass_bot: Pass TeleBot to handler.
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_chat_join_request_handler(handler_dict)
|
||||
|
||||
def _test_message_handler(self, message_handler, message):
|
||||
@ -3409,7 +3476,7 @@ class TeleBot:
|
||||
for message in new_messages:
|
||||
for message_handler in handlers:
|
||||
if self._test_message_handler(message_handler, message):
|
||||
self._exec_task(message_handler['function'], message)
|
||||
self._exec_task(message_handler['function'], message, pass_bot=message_handler['pass_bot'], task_type='handler')
|
||||
break
|
||||
|
||||
|
||||
|
@ -13,6 +13,9 @@ import telebot.util
|
||||
import telebot.types
|
||||
|
||||
|
||||
# storages
|
||||
from telebot.asyncio_storage import StateMemoryStorage, StatePickleStorage
|
||||
|
||||
from inspect import signature
|
||||
|
||||
from telebot import logger
|
||||
@ -161,7 +164,7 @@ class AsyncTeleBot:
|
||||
"""
|
||||
|
||||
def __init__(self, token: str, parse_mode: Optional[str]=None, offset=None,
|
||||
exception_handler=None) -> None: # TODO: ADD TYPEHINTS
|
||||
exception_handler=None, state_storage=StateMemoryStorage()) -> None: # TODO: ADD TYPEHINTS
|
||||
self.token = token
|
||||
|
||||
self.offset = offset
|
||||
@ -190,12 +193,13 @@ class AsyncTeleBot:
|
||||
self.custom_filters = {}
|
||||
self.state_handlers = []
|
||||
|
||||
self.current_states = asyncio_handler_backends.StateMemory()
|
||||
self.current_states = state_storage
|
||||
|
||||
|
||||
self.middlewares = []
|
||||
|
||||
|
||||
async def close_session(self):
|
||||
await asyncio_helper.session_manager.session.close()
|
||||
async def get_updates(self, offset: Optional[int]=None, limit: Optional[int]=None,
|
||||
timeout: Optional[int]=None, allowed_updates: Optional[List]=None, request_timeout: Optional[int]=None) -> List[types.Update]:
|
||||
json_updates = await asyncio_helper.get_updates(self.token, offset, limit, timeout, allowed_updates, request_timeout)
|
||||
@ -299,7 +303,7 @@ class AsyncTeleBot:
|
||||
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout)
|
||||
if updates:
|
||||
self.offset = updates[-1].update_id + 1
|
||||
self._loop_create_task(self.process_new_updates(updates)) # Seperate task for processing updates
|
||||
asyncio.create_task(self.process_new_updates(updates)) # Seperate task for processing updates
|
||||
if interval: await asyncio.sleep(interval)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
@ -322,6 +326,8 @@ class AsyncTeleBot:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error('Cause exception while getting updates.')
|
||||
if non_stop:
|
||||
@ -333,6 +339,7 @@ class AsyncTeleBot:
|
||||
|
||||
finally:
|
||||
self._polling = False
|
||||
await self.close_session()
|
||||
logger.warning('Polling is stopped.')
|
||||
|
||||
|
||||
@ -346,31 +353,48 @@ class AsyncTeleBot:
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
tasks = []
|
||||
for message in messages:
|
||||
middleware = await self.process_middlewares(message, update_type)
|
||||
self._loop_create_task(self._run_middlewares_and_handlers(handlers, message, middleware))
|
||||
tasks.append(self._run_middlewares_and_handlers(handlers, message, middleware))
|
||||
asyncio.gather(*tasks)
|
||||
|
||||
|
||||
async def _run_middlewares_and_handlers(self, handlers, message, middleware):
|
||||
handler_error = None
|
||||
data = {}
|
||||
for message_handler in handlers:
|
||||
process_update = await self._test_message_handler(message_handler, message)
|
||||
if not process_update:
|
||||
continue
|
||||
elif process_update:
|
||||
process_handler = True
|
||||
if middleware:
|
||||
middleware_result = await middleware.pre_process(message, data)
|
||||
if isinstance(middleware_result, SkipHandler):
|
||||
await middleware.post_process(message, data, handler_error)
|
||||
break
|
||||
process_handler = False
|
||||
if isinstance(middleware_result, CancelUpdate):
|
||||
return
|
||||
for handler in handlers:
|
||||
if not process_handler:
|
||||
break
|
||||
|
||||
process_update = await self._test_message_handler(handler, message)
|
||||
if not process_update:
|
||||
continue
|
||||
elif process_update:
|
||||
try:
|
||||
if "data" in signature(message_handler['function']).parameters:
|
||||
await message_handler['function'](message, data)
|
||||
else:
|
||||
await message_handler['function'](message)
|
||||
params = []
|
||||
|
||||
for i in signature(handler['function']).parameters:
|
||||
params.append(i)
|
||||
if len(params) == 1:
|
||||
await handler['function'](message)
|
||||
break
|
||||
if params[1] == 'data' and handler.get('pass_bot') is True:
|
||||
await handler['function'](message, data, self)
|
||||
break
|
||||
elif params[1] == 'data' and handler.get('pass_bot') is False:
|
||||
await handler['function'](message, data)
|
||||
break
|
||||
elif params[1] != 'data' and handler.get('pass_bot') is True:
|
||||
await handler['function'](message, self)
|
||||
break
|
||||
except Exception as e:
|
||||
handler_error = e
|
||||
@ -687,7 +711,7 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.message_handlers.append(handler_dict)
|
||||
|
||||
def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs):
|
||||
def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers message handler.
|
||||
:param callback: function to be called
|
||||
@ -696,8 +720,11 @@ class AsyncTeleBot:
|
||||
:param regexp:
|
||||
:param func:
|
||||
:param chat_types: True for private chat
|
||||
:param pass_bot: True if you want to get TeleBot instance in your handler
|
||||
:return: decorated function
|
||||
"""
|
||||
if content_types is None:
|
||||
content_types = ["text"]
|
||||
if isinstance(commands, str):
|
||||
logger.warning("register_message_handler: 'commands' filter should be List of strings (commands), not string.")
|
||||
commands = [commands]
|
||||
@ -712,6 +739,7 @@ class AsyncTeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_message_handler(handler_dict)
|
||||
|
||||
@ -759,7 +787,7 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.edited_message_handlers.append(handler_dict)
|
||||
|
||||
def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs):
|
||||
def register_edited_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers edited message handler.
|
||||
:param callback: function to be called
|
||||
@ -784,6 +812,7 @@ class AsyncTeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_edited_message_handler(handler_dict)
|
||||
|
||||
@ -829,7 +858,7 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.channel_post_handlers.append(handler_dict)
|
||||
|
||||
def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs):
|
||||
def register_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers channel post message handler.
|
||||
:param callback: function to be called
|
||||
@ -852,6 +881,7 @@ class AsyncTeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_channel_post_handler(handler_dict)
|
||||
|
||||
@ -896,7 +926,7 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.edited_channel_post_handlers.append(handler_dict)
|
||||
|
||||
def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, **kwargs):
|
||||
def register_edited_channel_post_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers edited channel post message handler.
|
||||
:param callback: function to be called
|
||||
@ -919,6 +949,7 @@ class AsyncTeleBot:
|
||||
commands=commands,
|
||||
regexp=regexp,
|
||||
func=func,
|
||||
pass_bot=pass_bot,
|
||||
**kwargs)
|
||||
self.add_edited_channel_post_handler(handler_dict)
|
||||
|
||||
@ -945,14 +976,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.inline_handlers.append(handler_dict)
|
||||
|
||||
def register_inline_handler(self, callback, func, **kwargs):
|
||||
def register_inline_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers inline handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_inline_handler(handler_dict)
|
||||
|
||||
def chosen_inline_handler(self, func, **kwargs):
|
||||
@ -978,14 +1009,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.chosen_inline_handlers.append(handler_dict)
|
||||
|
||||
def register_chosen_inline_handler(self, callback, func, **kwargs):
|
||||
def register_chosen_inline_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers chosen inline handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_chosen_inline_handler(handler_dict)
|
||||
|
||||
def callback_query_handler(self, func, **kwargs):
|
||||
@ -1011,14 +1042,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.callback_query_handlers.append(handler_dict)
|
||||
|
||||
def register_callback_query_handler(self, callback, func, **kwargs):
|
||||
def register_callback_query_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers callback query handler..
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_callback_query_handler(handler_dict)
|
||||
|
||||
def shipping_query_handler(self, func, **kwargs):
|
||||
@ -1044,14 +1075,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.shipping_query_handlers.append(handler_dict)
|
||||
|
||||
def register_shipping_query_handler(self, callback, func, **kwargs):
|
||||
def register_shipping_query_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers shipping query handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_shipping_query_handler(handler_dict)
|
||||
|
||||
def pre_checkout_query_handler(self, func, **kwargs):
|
||||
@ -1077,14 +1108,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.pre_checkout_query_handlers.append(handler_dict)
|
||||
|
||||
def register_pre_checkout_query_handler(self, callback, func, **kwargs):
|
||||
def register_pre_checkout_query_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers pre-checkout request handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_pre_checkout_query_handler(handler_dict)
|
||||
|
||||
def poll_handler(self, func, **kwargs):
|
||||
@ -1110,14 +1141,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.poll_handlers.append(handler_dict)
|
||||
|
||||
def register_poll_handler(self, callback, func, **kwargs):
|
||||
def register_poll_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers poll handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_poll_handler(handler_dict)
|
||||
|
||||
def poll_answer_handler(self, func=None, **kwargs):
|
||||
@ -1143,14 +1174,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.poll_answer_handlers.append(handler_dict)
|
||||
|
||||
def register_poll_answer_handler(self, callback, func, **kwargs):
|
||||
def register_poll_answer_handler(self, callback, func, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers poll answer handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_poll_answer_handler(handler_dict)
|
||||
|
||||
def my_chat_member_handler(self, func=None, **kwargs):
|
||||
@ -1176,14 +1207,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.my_chat_member_handlers.append(handler_dict)
|
||||
|
||||
def register_my_chat_member_handler(self, callback, func=None, **kwargs):
|
||||
def register_my_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers my chat member handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_my_chat_member_handler(handler_dict)
|
||||
|
||||
def chat_member_handler(self, func=None, **kwargs):
|
||||
@ -1209,14 +1240,14 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.chat_member_handlers.append(handler_dict)
|
||||
|
||||
def register_chat_member_handler(self, callback, func=None, **kwargs):
|
||||
def register_chat_member_handler(self, callback, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers chat member handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_chat_member_handler(handler_dict)
|
||||
|
||||
def chat_join_request_handler(self, func=None, **kwargs):
|
||||
@ -1242,18 +1273,18 @@ class AsyncTeleBot:
|
||||
"""
|
||||
self.chat_join_request_handlers.append(handler_dict)
|
||||
|
||||
def register_chat_join_request_handler(self, callback, func=None, **kwargs):
|
||||
def register_chat_join_request_handler(self, callback, func=None, pass_bot=False, **kwargs):
|
||||
"""
|
||||
Registers chat join request handler.
|
||||
:param callback: function to be called
|
||||
:param func:
|
||||
:return: decorated function
|
||||
"""
|
||||
handler_dict = self._build_handler_dict(callback, func=func, **kwargs)
|
||||
handler_dict = self._build_handler_dict(callback, func=func, pass_bot=pass_bot, **kwargs)
|
||||
self.add_chat_join_request_handler(handler_dict)
|
||||
|
||||
@staticmethod
|
||||
def _build_handler_dict(handler, **filters):
|
||||
def _build_handler_dict(handler, pass_bot=False, **filters):
|
||||
"""
|
||||
Builds a dictionary for a handler
|
||||
:param handler:
|
||||
@ -1262,6 +1293,7 @@ class AsyncTeleBot:
|
||||
"""
|
||||
return {
|
||||
'function': handler,
|
||||
'pass_bot': pass_bot,
|
||||
'filters': {ftype: fvalue for ftype, fvalue in filters.items() if fvalue is not None}
|
||||
# Remove None values, they are skipped in _test_filter anyway
|
||||
#'filters': filters
|
||||
@ -1324,8 +1356,7 @@ class AsyncTeleBot:
|
||||
:param filename: Filename of saving file
|
||||
"""
|
||||
|
||||
self.current_states = asyncio_handler_backends.StateFile(filename=filename)
|
||||
self.current_states.create_dir()
|
||||
self.current_states = StatePickleStorage(file_path=filename)
|
||||
|
||||
async def set_webhook(self, url=None, certificate=None, max_connections=None, allowed_updates=None, ip_address=None,
|
||||
drop_pending_updates = None, timeout=None):
|
||||
@ -1356,6 +1387,8 @@ class AsyncTeleBot:
|
||||
return await asyncio_helper.set_webhook(self.token, url, certificate, max_connections, allowed_updates, ip_address,
|
||||
drop_pending_updates, timeout)
|
||||
|
||||
|
||||
|
||||
async def delete_webhook(self, drop_pending_updates=None, timeout=None):
|
||||
"""
|
||||
Use this method to remove webhook integration if you decide to switch back to getUpdates.
|
||||
@ -1366,6 +1399,12 @@ class AsyncTeleBot:
|
||||
"""
|
||||
return await asyncio_helper.delete_webhook(self.token, drop_pending_updates, timeout)
|
||||
|
||||
async def remove_webhook(self):
|
||||
"""
|
||||
Alternative for delete_webhook but uses set_webhook
|
||||
"""
|
||||
self.set_webhook()
|
||||
|
||||
async def get_webhook_info(self, timeout=None):
|
||||
"""
|
||||
Use this method to get current webhook status. Requires no parameters.
|
||||
@ -2443,8 +2482,8 @@ class AsyncTeleBot:
|
||||
"""
|
||||
return await asyncio_helper.delete_chat_photo(self.token, chat_id)
|
||||
|
||||
async def get_my_commands(self, scope: Optional[types.BotCommandScope]=None,
|
||||
language_code: Optional[str]=None) -> List[types.BotCommand]:
|
||||
async def get_my_commands(self, scope: Optional[types.BotCommandScope],
|
||||
language_code: Optional[str]) -> List[types.BotCommand]:
|
||||
"""
|
||||
Use this method to get the current list of the bot's commands.
|
||||
Returns List of BotCommand on success.
|
||||
@ -3019,37 +3058,57 @@ class AsyncTeleBot:
|
||||
return await asyncio_helper.delete_sticker_from_set(self.token, sticker)
|
||||
|
||||
|
||||
async def set_state(self, chat_id, state):
|
||||
async def set_state(self, user_id: int, state: str, chat_id: int=None):
|
||||
"""
|
||||
Sets a new state of a user.
|
||||
:param chat_id:
|
||||
:param state: new state. can be string or integer.
|
||||
"""
|
||||
await self.current_states.add_state(chat_id, state)
|
||||
if not chat_id:
|
||||
chat_id = user_id
|
||||
await self.current_states.set_state(chat_id, user_id, state)
|
||||
|
||||
async def delete_state(self, chat_id):
|
||||
async def reset_data(self, user_id: int, chat_id: int=None):
|
||||
"""
|
||||
Reset data for a user in chat.
|
||||
:param user_id:
|
||||
:param chat_id:
|
||||
"""
|
||||
if chat_id is None:
|
||||
chat_id = user_id
|
||||
await self.current_states.reset_data(chat_id, user_id)
|
||||
|
||||
async def delete_state(self, user_id: int, chat_id:int=None):
|
||||
"""
|
||||
Delete the current state of a user.
|
||||
:param chat_id:
|
||||
:return:
|
||||
"""
|
||||
await self.current_states.delete_state(chat_id)
|
||||
if not chat_id:
|
||||
chat_id = user_id
|
||||
await self.current_states.delete_state(chat_id, user_id)
|
||||
|
||||
def retrieve_data(self, chat_id):
|
||||
return self.current_states.retrieve_data(chat_id)
|
||||
def retrieve_data(self, user_id: int, chat_id: int=None):
|
||||
if not chat_id:
|
||||
chat_id = user_id
|
||||
return self.current_states.get_interactive_data(chat_id, user_id)
|
||||
|
||||
async def get_state(self, chat_id):
|
||||
async def get_state(self, user_id, chat_id: int=None):
|
||||
"""
|
||||
Get current state of a user.
|
||||
:param chat_id:
|
||||
:return: state of a user
|
||||
"""
|
||||
return await self.current_states.current_state(chat_id)
|
||||
if not chat_id:
|
||||
chat_id = user_id
|
||||
return await self.current_states.get_state(chat_id, user_id)
|
||||
|
||||
async def add_data(self, chat_id, **kwargs):
|
||||
async def add_data(self, user_id: int, chat_id: int=None, **kwargs):
|
||||
"""
|
||||
Add data to states.
|
||||
:param chat_id:
|
||||
"""
|
||||
if not chat_id:
|
||||
chat_id = user_id
|
||||
for key, value in kwargs.items():
|
||||
await self.current_states.add_data(chat_id, key, value)
|
||||
await self.current_states.set_data(chat_id, user_id, key, value)
|
||||
|
@ -159,11 +159,21 @@ class StateFilter(AdvancedCustomFilter):
|
||||
key = 'state'
|
||||
|
||||
async def check(self, message, text):
|
||||
result = await self.bot.current_states.current_state(message.from_user.id)
|
||||
if result is False: return False
|
||||
elif text == '*': return True
|
||||
elif type(text) is list: return result in text
|
||||
return result == text
|
||||
if text == '*': return True
|
||||
if message.chat.type == 'group':
|
||||
group_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id)
|
||||
if group_state == text:
|
||||
return True
|
||||
elif group_state in text and type(text) is list:
|
||||
return True
|
||||
|
||||
|
||||
else:
|
||||
user_state = await self.bot.current_states.get_state(message.chat.id,message.from_user.id)
|
||||
if user_state == text:
|
||||
return True
|
||||
elif type(text) is list and user_state in text:
|
||||
return True
|
||||
|
||||
class IsDigitFilter(SimpleCustomFilter):
|
||||
"""
|
||||
|
@ -3,206 +3,6 @@ import pickle
|
||||
|
||||
|
||||
|
||||
class StateMemory:
|
||||
def __init__(self):
|
||||
self._states = {}
|
||||
|
||||
async def add_state(self, chat_id, state):
|
||||
"""
|
||||
Add a state.
|
||||
:param chat_id:
|
||||
:param state: new state
|
||||
"""
|
||||
if chat_id in self._states:
|
||||
|
||||
self._states[chat_id]['state'] = state
|
||||
else:
|
||||
self._states[chat_id] = {'state': state,'data': {}}
|
||||
|
||||
async def current_state(self, chat_id):
|
||||
"""Current state"""
|
||||
if chat_id in self._states: return self._states[chat_id]['state']
|
||||
else: return False
|
||||
|
||||
async def delete_state(self, chat_id):
|
||||
"""Delete a state"""
|
||||
self._states.pop(chat_id)
|
||||
|
||||
def get_data(self, chat_id):
|
||||
return self._states[chat_id]['data']
|
||||
|
||||
async def set(self, chat_id, new_state):
|
||||
"""
|
||||
Set a new state for a user.
|
||||
:param chat_id:
|
||||
:param new_state: new_state of a user
|
||||
"""
|
||||
await self.add_state(chat_id,new_state)
|
||||
|
||||
async def add_data(self, chat_id, key, value):
|
||||
result = self._states[chat_id]['data'][key] = value
|
||||
return result
|
||||
|
||||
async def finish(self, chat_id):
|
||||
"""
|
||||
Finish(delete) state of a user.
|
||||
:param chat_id:
|
||||
"""
|
||||
await self.delete_state(chat_id)
|
||||
|
||||
def retrieve_data(self, chat_id):
|
||||
"""
|
||||
Save input text.
|
||||
|
||||
Usage:
|
||||
with bot.retrieve_data(message.chat.id) as data:
|
||||
data['name'] = message.text
|
||||
|
||||
Also, at the end of your 'Form' you can get the name:
|
||||
data['name']
|
||||
"""
|
||||
return StateContext(self, chat_id)
|
||||
|
||||
|
||||
class StateFile:
|
||||
"""
|
||||
Class to save states in a file.
|
||||
"""
|
||||
def __init__(self, filename):
|
||||
self.file_path = filename
|
||||
|
||||
async def add_state(self, chat_id, state):
|
||||
"""
|
||||
Add a state.
|
||||
:param chat_id:
|
||||
:param state: new state
|
||||
"""
|
||||
states_data = self.read_data()
|
||||
if chat_id in states_data:
|
||||
states_data[chat_id]['state'] = state
|
||||
return await self.save_data(states_data)
|
||||
else:
|
||||
states_data[chat_id] = {'state': state,'data': {}}
|
||||
return await self.save_data(states_data)
|
||||
|
||||
|
||||
async def current_state(self, chat_id):
|
||||
"""Current state."""
|
||||
states_data = self.read_data()
|
||||
if chat_id in states_data: return states_data[chat_id]['state']
|
||||
else: return False
|
||||
|
||||
async def delete_state(self, chat_id):
|
||||
"""Delete a state"""
|
||||
states_data = self.read_data()
|
||||
states_data.pop(chat_id)
|
||||
await self.save_data(states_data)
|
||||
|
||||
def read_data(self):
|
||||
"""
|
||||
Read the data from file.
|
||||
"""
|
||||
file = open(self.file_path, 'rb')
|
||||
states_data = pickle.load(file)
|
||||
file.close()
|
||||
return states_data
|
||||
|
||||
def create_dir(self):
|
||||
"""
|
||||
Create directory .save-handlers.
|
||||
"""
|
||||
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
|
||||
os.makedirs(dirs, exist_ok=True)
|
||||
if not os.path.isfile(self.file_path):
|
||||
with open(self.file_path,'wb') as file:
|
||||
pickle.dump({}, file)
|
||||
|
||||
async def save_data(self, new_data):
|
||||
"""
|
||||
Save data after editing.
|
||||
:param new_data:
|
||||
"""
|
||||
with open(self.file_path, 'wb+') as state_file:
|
||||
pickle.dump(new_data, state_file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return True
|
||||
|
||||
def get_data(self, chat_id):
|
||||
return self.read_data()[chat_id]['data']
|
||||
|
||||
async def set(self, chat_id, new_state):
|
||||
"""
|
||||
Set a new state for a user.
|
||||
:param chat_id:
|
||||
:param new_state: new_state of a user
|
||||
|
||||
"""
|
||||
await self.add_state(chat_id,new_state)
|
||||
|
||||
async def add_data(self, chat_id, key, value):
|
||||
states_data = self.read_data()
|
||||
result = states_data[chat_id]['data'][key] = value
|
||||
await self.save_data(result)
|
||||
|
||||
return result
|
||||
|
||||
async def finish(self, chat_id):
|
||||
"""
|
||||
Finish(delete) state of a user.
|
||||
:param chat_id:
|
||||
"""
|
||||
await self.delete_state(chat_id)
|
||||
|
||||
def retrieve_data(self, chat_id):
|
||||
"""
|
||||
Save input text.
|
||||
|
||||
Usage:
|
||||
with bot.retrieve_data(message.chat.id) as data:
|
||||
data['name'] = message.text
|
||||
|
||||
Also, at the end of your 'Form' you can get the name:
|
||||
data['name']
|
||||
"""
|
||||
return StateFileContext(self, chat_id)
|
||||
|
||||
|
||||
class StateContext:
|
||||
"""
|
||||
Class for data.
|
||||
"""
|
||||
def __init__(self , obj: StateMemory, chat_id) -> None:
|
||||
self.obj = obj
|
||||
self.chat_id = chat_id
|
||||
self.data = obj.get_data(chat_id)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.data
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return
|
||||
|
||||
class StateFileContext:
|
||||
"""
|
||||
Class for data.
|
||||
"""
|
||||
def __init__(self , obj: StateFile, chat_id) -> None:
|
||||
self.obj = obj
|
||||
self.chat_id = chat_id
|
||||
self.data = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.data = self.obj.get_data(self.chat_id)
|
||||
return self.data
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
old_data = self.obj.read_data()
|
||||
for i in self.data:
|
||||
old_data[self.chat_id]['data'][i] = self.data.get(i)
|
||||
await self.obj.save_data(old_data)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class BaseMiddleware:
|
||||
"""
|
||||
Base class for middleware.
|
||||
@ -217,3 +17,19 @@ class BaseMiddleware:
|
||||
async def post_process(self, message, data, exception):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class State:
|
||||
def __init__(self) -> None:
|
||||
self.name = None
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class StatesGroup:
|
||||
def __init_subclass__(cls) -> None:
|
||||
# print all variables of a subclass
|
||||
for name, value in cls.__dict__.items():
|
||||
if not name.startswith('__') and not callable(value) and isinstance(value, State):
|
||||
# change value of that variable
|
||||
value.name = ':'.join((cls.__name__, name))
|
||||
|
||||
|
@ -12,16 +12,8 @@ API_URL = 'https://api.telegram.org/bot{0}/{1}'
|
||||
from datetime import datetime
|
||||
|
||||
import telebot
|
||||
from telebot import util
|
||||
from telebot import util, logger
|
||||
|
||||
class SessionBase:
|
||||
def __init__(self) -> None:
|
||||
self.session = None
|
||||
async def _get_new_session(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self.session
|
||||
|
||||
session_manager = SessionBase()
|
||||
|
||||
proxy = None
|
||||
session = None
|
||||
@ -36,6 +28,29 @@ REQUEST_TIMEOUT = 10
|
||||
MAX_RETRIES = 3
|
||||
logger = telebot.logger
|
||||
|
||||
|
||||
REQUEST_LIMIT = 50
|
||||
|
||||
class SessionManager:
|
||||
def __init__(self) -> None:
|
||||
self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=REQUEST_LIMIT))
|
||||
|
||||
async def create_session(self):
|
||||
self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=REQUEST_LIMIT))
|
||||
return self.session
|
||||
|
||||
async def get_session(self):
|
||||
if self.session.closed:
|
||||
self.session = await self.create_session()
|
||||
|
||||
if not self.session._loop.is_running():
|
||||
await self.session.close()
|
||||
self.session = await self.create_session()
|
||||
return self.session
|
||||
|
||||
|
||||
session_manager = SessionManager()
|
||||
|
||||
async def _process_request(token, url, method='get', params=None, files=None, request_timeout=None):
|
||||
params = prepare_data(params, files)
|
||||
if request_timeout is None:
|
||||
@ -43,19 +58,21 @@ async def _process_request(token, url, method='get', params=None, files=None, re
|
||||
timeout = aiohttp.ClientTimeout(total=request_timeout)
|
||||
got_result = False
|
||||
current_try=0
|
||||
async with await session_manager._get_new_session() as session:
|
||||
session = await session_manager.get_session()
|
||||
while not got_result and current_try<MAX_RETRIES-1:
|
||||
current_try +=1
|
||||
try:
|
||||
response = await session.request(method=method, url=API_URL.format(token, url), data=params, timeout=timeout)
|
||||
async with session.request(method=method, url=API_URL.format(token, url), data=params, timeout=timeout) as resp:
|
||||
logger.debug("Request: method={0} url={1} params={2} files={3} request_timeout={4} current_try={5}".format(method, url, params, files, request_timeout, current_try).replace(token, token.split(':')[0] + ":{TOKEN}"))
|
||||
json_result = await _check_result(url, response)
|
||||
json_result = await _check_result(url, resp)
|
||||
if json_result:
|
||||
return json_result['result']
|
||||
except (ApiTelegramException,ApiInvalidJSONException, ApiHTTPException) as e:
|
||||
raise e
|
||||
except:
|
||||
pass
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error('Aiohttp ClientError: {0}'.format(e.__class__.__name__))
|
||||
except Exception as e:
|
||||
logger.error(f'Unkown error: {e.__class__.__name__}')
|
||||
if not got_result:
|
||||
raise RequestTimeout("Request timeout. Request: method={0} url={1} params={2} files={3} request_timeout={4}".format(method, url, params, files, request_timeout, current_try))
|
||||
|
||||
@ -143,8 +160,7 @@ async def download_file(token, file_path):
|
||||
else:
|
||||
# noinspection PyUnresolvedReferences
|
||||
url = FILE_URL.format(token, file_path)
|
||||
# TODO: rewrite this method
|
||||
async with await session_manager._get_new_session() as session:
|
||||
async with await session_manager.get_session() as session:
|
||||
async with session.get(url, proxy=proxy) as response:
|
||||
result = await response.read()
|
||||
if response.status != 200:
|
||||
@ -279,7 +295,7 @@ async def send_message(
|
||||
|
||||
return await _process_request(token, method_name, params=params)
|
||||
|
||||
# here shit begins
|
||||
# methods
|
||||
|
||||
async def get_user_profile_photos(token, user_id, offset=None, limit=None):
|
||||
method_url = r'getUserProfilePhotos'
|
||||
|
13
telebot/asyncio_storage/__init__.py
Normal file
13
telebot/asyncio_storage/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from telebot.asyncio_storage.memory_storage import StateMemoryStorage
|
||||
from telebot.asyncio_storage.redis_storage import StateRedisStorage
|
||||
from telebot.asyncio_storage.pickle_storage import StatePickleStorage
|
||||
from telebot.asyncio_storage.base_storage import StateContext,StateStorageBase
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
'StateStorageBase', 'StateContext',
|
||||
'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage'
|
||||
]
|
69
telebot/asyncio_storage/base_storage.py
Normal file
69
telebot/asyncio_storage/base_storage.py
Normal file
@ -0,0 +1,69 @@
|
||||
import copy
|
||||
|
||||
|
||||
class StateStorageBase:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
async def set_data(self, chat_id, user_id, key, value):
|
||||
"""
|
||||
Set data for a user in a particular chat.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_data(self, chat_id, user_id):
|
||||
"""
|
||||
Get data for a user in a particular chat.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def set_state(self, chat_id, user_id, state):
|
||||
"""
|
||||
Set state for a particular user.
|
||||
|
||||
! Note that you should create a
|
||||
record if it does not exist, and
|
||||
if a record with state already exists,
|
||||
you need to update a record.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_state(self, chat_id, user_id):
|
||||
"""
|
||||
Delete state for a particular user.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reset_data(self, chat_id, user_id):
|
||||
"""
|
||||
Reset data for a particular user in a chat.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_state(self, chat_id, user_id):
|
||||
raise NotImplementedError
|
||||
|
||||
async def save(chat_id, user_id, data):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
class StateContext:
|
||||
"""
|
||||
Class for data.
|
||||
"""
|
||||
|
||||
def __init__(self, obj, chat_id, user_id):
|
||||
self.obj = obj
|
||||
self.data = None
|
||||
self.chat_id = chat_id
|
||||
self.user_id = user_id
|
||||
|
||||
|
||||
|
||||
async def __aenter__(self):
|
||||
self.data = copy.deepcopy(await self.obj.get_data(self.chat_id, self.user_id))
|
||||
return self.data
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return await self.obj.save(self.chat_id, self.user_id, self.data)
|
64
telebot/asyncio_storage/memory_storage.py
Normal file
64
telebot/asyncio_storage/memory_storage.py
Normal file
@ -0,0 +1,64 @@
|
||||
from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
|
||||
|
||||
class StateMemoryStorage(StateStorageBase):
|
||||
def __init__(self) -> None:
|
||||
self.data = {}
|
||||
#
|
||||
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
|
||||
|
||||
|
||||
async def set_state(self, chat_id, user_id, state):
|
||||
if chat_id in self.data:
|
||||
if user_id in self.data[chat_id]:
|
||||
self.data[chat_id][user_id]['state'] = state
|
||||
return True
|
||||
else:
|
||||
self.data[chat_id][user_id] = {'state': state, 'data': {}}
|
||||
return True
|
||||
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
|
||||
return True
|
||||
|
||||
async def delete_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
del self.data[chat_id][user_id]
|
||||
if chat_id == user_id:
|
||||
del self.data[chat_id]
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def get_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['state']
|
||||
|
||||
return None
|
||||
async def get_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['data']
|
||||
|
||||
return None
|
||||
|
||||
async def reset_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'] = {}
|
||||
return True
|
||||
return False
|
||||
|
||||
async def set_data(self, chat_id, user_id, key, value):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'][key] = value
|
||||
return True
|
||||
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
|
||||
|
||||
def get_interactive_data(self, chat_id, user_id):
|
||||
return StateContext(self, chat_id, user_id)
|
||||
|
||||
async def save(self, chat_id, user_id, data):
|
||||
self.data[chat_id][user_id]['data'] = data
|
107
telebot/asyncio_storage/pickle_storage.py
Normal file
107
telebot/asyncio_storage/pickle_storage.py
Normal file
@ -0,0 +1,107 @@
|
||||
from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
|
||||
import os
|
||||
|
||||
|
||||
import pickle
|
||||
|
||||
|
||||
class StatePickleStorage(StateStorageBase):
|
||||
def __init__(self, file_path="./.state-save/states.pkl") -> None:
|
||||
self.file_path = file_path
|
||||
self.create_dir()
|
||||
self.data = self.read()
|
||||
|
||||
async def convert_old_to_new(self):
|
||||
# old looks like:
|
||||
# {1: {'state': 'start', 'data': {'name': 'John'}}
|
||||
# we should update old version pickle to new.
|
||||
# new looks like:
|
||||
# {1: {2: {'state': 'start', 'data': {'name': 'John'}}}}
|
||||
new_data = {}
|
||||
for key, value in self.data.items():
|
||||
# this returns us id and dict with data and state
|
||||
new_data[key] = {key: value} # convert this to new
|
||||
# pass it to global data
|
||||
self.data = new_data
|
||||
self.update_data() # update data in file
|
||||
|
||||
def create_dir(self):
|
||||
"""
|
||||
Create directory .save-handlers.
|
||||
"""
|
||||
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
|
||||
os.makedirs(dirs, exist_ok=True)
|
||||
if not os.path.isfile(self.file_path):
|
||||
with open(self.file_path,'wb') as file:
|
||||
pickle.dump({}, file)
|
||||
|
||||
def read(self):
|
||||
file = open(self.file_path, 'rb')
|
||||
data = pickle.load(file)
|
||||
file.close()
|
||||
return data
|
||||
|
||||
def update_data(self):
|
||||
file = open(self.file_path, 'wb+')
|
||||
pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
file.close()
|
||||
|
||||
async def set_state(self, chat_id, user_id, state):
|
||||
if chat_id in self.data:
|
||||
if user_id in self.data[chat_id]:
|
||||
self.data[chat_id][user_id]['state'] = state
|
||||
return True
|
||||
else:
|
||||
self.data[chat_id][user_id] = {'state': state, 'data': {}}
|
||||
return True
|
||||
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
|
||||
self.update_data()
|
||||
return True
|
||||
|
||||
async def delete_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
del self.data[chat_id][user_id]
|
||||
if chat_id == user_id:
|
||||
del self.data[chat_id]
|
||||
self.update_data()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def get_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['state']
|
||||
|
||||
return None
|
||||
async def get_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['data']
|
||||
|
||||
return None
|
||||
|
||||
async def reset_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'] = {}
|
||||
self.update_data()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def set_data(self, chat_id, user_id, key, value):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'][key] = value
|
||||
self.update_data()
|
||||
return True
|
||||
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
|
||||
|
||||
def get_interactive_data(self, chat_id, user_id):
|
||||
return StateContext(self, chat_id, user_id)
|
||||
|
||||
async def save(self, chat_id, user_id, data):
|
||||
self.data[chat_id][user_id]['data'] = data
|
||||
self.update_data()
|
178
telebot/asyncio_storage/redis_storage.py
Normal file
178
telebot/asyncio_storage/redis_storage.py
Normal file
@ -0,0 +1,178 @@
|
||||
from pickle import FALSE
|
||||
from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
|
||||
import json
|
||||
|
||||
redis_installed = True
|
||||
try:
|
||||
import aioredis
|
||||
except:
|
||||
redis_installed = False
|
||||
|
||||
|
||||
class StateRedisStorage(StateStorageBase):
|
||||
"""
|
||||
This class is for Redis storage.
|
||||
This will work only for states.
|
||||
To use it, just pass this class to:
|
||||
TeleBot(storage=StateRedisStorage())
|
||||
"""
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'):
|
||||
if not redis_installed:
|
||||
raise ImportError('AioRedis is not installed. Install it via "pip install aioredis"')
|
||||
|
||||
|
||||
aioredis_version = tuple(map(int, aioredis.__version__.split(".")[0]))
|
||||
if aioredis_version < (2,):
|
||||
raise ImportError('Invalid aioredis version. Aioredis version should be >= 2.0.0')
|
||||
self.redis = aioredis.Redis(host=host, port=port, db=db, password=password)
|
||||
|
||||
self.prefix = prefix
|
||||
#self.con = Redis(connection_pool=self.redis) -> use this when necessary
|
||||
#
|
||||
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
|
||||
|
||||
async def get_record(self, key):
|
||||
"""
|
||||
Function to get record from database.
|
||||
It has nothing to do with states.
|
||||
Made for backend compatibility
|
||||
"""
|
||||
result = await self.redis.get(self.prefix+str(key))
|
||||
if result: return json.loads(result)
|
||||
return
|
||||
|
||||
async def set_record(self, key, value):
|
||||
"""
|
||||
Function to set record to database.
|
||||
It has nothing to do with states.
|
||||
Made for backend compatibility
|
||||
"""
|
||||
|
||||
await self.redis.set(self.prefix+str(key), json.dumps(value))
|
||||
return True
|
||||
|
||||
async def delete_record(self, key):
|
||||
"""
|
||||
Function to delete record from database.
|
||||
It has nothing to do with states.
|
||||
Made for backend compatibility
|
||||
"""
|
||||
await self.redis.delete(self.prefix+str(key))
|
||||
return True
|
||||
|
||||
async def set_state(self, chat_id, user_id, state):
|
||||
"""
|
||||
Set state for a particular user in a chat.
|
||||
"""
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['state'] = state
|
||||
else:
|
||||
response[user_id] = {'state': state, 'data': {}}
|
||||
else:
|
||||
response = {user_id: {'state': state, 'data': {}}}
|
||||
await self.set_record(chat_id, response)
|
||||
|
||||
return True
|
||||
|
||||
async def delete_state(self, chat_id, user_id):
|
||||
"""
|
||||
Delete state for a particular user in a chat.
|
||||
"""
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
del response[user_id]
|
||||
if user_id == str(chat_id):
|
||||
await self.delete_record(chat_id)
|
||||
return True
|
||||
else: await self.set_record(chat_id, response)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def get_value(self, chat_id, user_id, key):
|
||||
"""
|
||||
Get value for a data of a user in a chat.
|
||||
"""
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
if key in response[user_id]['data']:
|
||||
return response[user_id]['data'][key]
|
||||
return None
|
||||
|
||||
|
||||
async def get_state(self, chat_id, user_id):
|
||||
"""
|
||||
Get state of a user in a chat.
|
||||
"""
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
return response[user_id]['state']
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_data(self, chat_id, user_id):
|
||||
"""
|
||||
Get data of particular user in a particular chat.
|
||||
"""
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
return response[user_id]['data']
|
||||
return None
|
||||
|
||||
|
||||
async def reset_data(self, chat_id, user_id):
|
||||
"""
|
||||
Reset data of a user in a chat.
|
||||
"""
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['data'] = {}
|
||||
await self.set_record(chat_id, response)
|
||||
return True
|
||||
|
||||
|
||||
|
||||
|
||||
async def set_data(self, chat_id, user_id, key, value):
|
||||
"""
|
||||
Set data without interactive data.
|
||||
"""
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['data'][key] = value
|
||||
await self.set_record(chat_id, response)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_interactive_data(self, chat_id, user_id):
|
||||
"""
|
||||
Get Data in interactive way.
|
||||
You can use with() with this function.
|
||||
"""
|
||||
return StateContext(self, chat_id, user_id)
|
||||
|
||||
async def save(self, chat_id, user_id, data):
|
||||
response = await self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['data'] = dict(data, **response[user_id]['data'])
|
||||
await self.set_record(chat_id, response)
|
||||
return True
|
||||
|
@ -158,11 +158,21 @@ class StateFilter(AdvancedCustomFilter):
|
||||
key = 'state'
|
||||
|
||||
def check(self, message, text):
|
||||
if self.bot.current_states.current_state(message.from_user.id) is False: return False
|
||||
elif text == '*': return True
|
||||
elif type(text) is list: return self.bot.current_states.current_state(message.from_user.id) in text
|
||||
return self.bot.current_states.current_state(message.from_user.id) == text
|
||||
if text == '*': return True
|
||||
if message.chat.type == 'group':
|
||||
group_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id)
|
||||
if group_state == text:
|
||||
return True
|
||||
elif group_state in text and type(text) is list:
|
||||
return True
|
||||
|
||||
|
||||
else:
|
||||
user_state = self.bot.current_states.get_state(message.chat.id,message.from_user.id)
|
||||
if user_state == text:
|
||||
return True
|
||||
elif type(text) is list and user_state in text:
|
||||
return True
|
||||
class IsDigitFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Filter to check whether the string is made up of only digits.
|
||||
|
@ -3,6 +3,11 @@ import pickle
|
||||
import threading
|
||||
|
||||
from telebot import apihelper
|
||||
try:
|
||||
from redis import Redis
|
||||
redis_installed = True
|
||||
except:
|
||||
redis_installed = False
|
||||
|
||||
|
||||
class HandlerBackend(object):
|
||||
@ -116,7 +121,8 @@ class FileHandlerBackend(HandlerBackend):
|
||||
class RedisHandlerBackend(HandlerBackend):
|
||||
def __init__(self, handlers=None, host='localhost', port=6379, db=0, prefix='telebot', password=None):
|
||||
super(RedisHandlerBackend, self).__init__(handlers)
|
||||
from redis import Redis
|
||||
if not redis_installed:
|
||||
raise Exception("Redis is not installed. Install it via 'pip install redis'")
|
||||
self.prefix = prefix
|
||||
self.redis = Redis(host, port, db, password)
|
||||
|
||||
@ -143,197 +149,19 @@ class RedisHandlerBackend(HandlerBackend):
|
||||
return handlers
|
||||
|
||||
|
||||
class StateMemory:
|
||||
def __init__(self):
|
||||
self._states = {}
|
||||
|
||||
def add_state(self, chat_id, state):
|
||||
"""
|
||||
Add a state.
|
||||
:param chat_id:
|
||||
:param state: new state
|
||||
"""
|
||||
if chat_id in self._states:
|
||||
|
||||
self._states[chat_id]['state'] = state
|
||||
else:
|
||||
self._states[chat_id] = {'state': state,'data': {}}
|
||||
|
||||
def current_state(self, chat_id):
|
||||
"""Current state"""
|
||||
if chat_id in self._states: return self._states[chat_id]['state']
|
||||
else: return False
|
||||
|
||||
def delete_state(self, chat_id):
|
||||
"""Delete a state"""
|
||||
self._states.pop(chat_id)
|
||||
|
||||
def get_data(self, chat_id):
|
||||
return self._states[chat_id]['data']
|
||||
|
||||
def set(self, chat_id, new_state):
|
||||
"""
|
||||
Set a new state for a user.
|
||||
:param chat_id:
|
||||
:param new_state: new_state of a user
|
||||
"""
|
||||
self.add_state(chat_id,new_state)
|
||||
|
||||
def add_data(self, chat_id, key, value):
|
||||
result = self._states[chat_id]['data'][key] = value
|
||||
return result
|
||||
|
||||
def finish(self, chat_id):
|
||||
"""
|
||||
Finish(delete) state of a user.
|
||||
:param chat_id:
|
||||
"""
|
||||
self.delete_state(chat_id)
|
||||
|
||||
def retrieve_data(self, chat_id):
|
||||
"""
|
||||
Save input text.
|
||||
|
||||
Usage:
|
||||
with bot.retrieve_data(message.chat.id) as data:
|
||||
data['name'] = message.text
|
||||
|
||||
Also, at the end of your 'Form' you can get the name:
|
||||
data['name']
|
||||
"""
|
||||
return StateContext(self, chat_id)
|
||||
class State:
|
||||
def __init__(self) -> None:
|
||||
self.name = None
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class StateFile:
|
||||
"""
|
||||
Class to save states in a file.
|
||||
"""
|
||||
def __init__(self, filename):
|
||||
self.file_path = filename
|
||||
|
||||
def add_state(self, chat_id, state):
|
||||
"""
|
||||
Add a state.
|
||||
:param chat_id:
|
||||
:param state: new state
|
||||
"""
|
||||
states_data = self.read_data()
|
||||
if chat_id in states_data:
|
||||
states_data[chat_id]['state'] = state
|
||||
return self.save_data(states_data)
|
||||
else:
|
||||
states_data[chat_id] = {'state': state,'data': {}}
|
||||
return self.save_data(states_data)
|
||||
|
||||
def current_state(self, chat_id):
|
||||
"""Current state."""
|
||||
states_data = self.read_data()
|
||||
if chat_id in states_data: return states_data[chat_id]['state']
|
||||
else: return False
|
||||
|
||||
def delete_state(self, chat_id):
|
||||
"""Delete a state"""
|
||||
states_data = self.read_data()
|
||||
states_data.pop(chat_id)
|
||||
self.save_data(states_data)
|
||||
|
||||
def read_data(self):
|
||||
"""
|
||||
Read the data from file.
|
||||
"""
|
||||
file = open(self.file_path, 'rb')
|
||||
states_data = pickle.load(file)
|
||||
file.close()
|
||||
return states_data
|
||||
|
||||
def create_dir(self):
|
||||
"""
|
||||
Create directory .save-handlers.
|
||||
"""
|
||||
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
|
||||
os.makedirs(dirs, exist_ok=True)
|
||||
if not os.path.isfile(self.file_path):
|
||||
with open(self.file_path,'wb') as file:
|
||||
pickle.dump({}, file)
|
||||
|
||||
def save_data(self, new_data):
|
||||
"""
|
||||
Save data after editing.
|
||||
:param new_data:
|
||||
"""
|
||||
with open(self.file_path, 'wb+') as state_file:
|
||||
pickle.dump(new_data, state_file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return True
|
||||
|
||||
def get_data(self, chat_id):
|
||||
return self.read_data()[chat_id]['data']
|
||||
|
||||
def set(self, chat_id, new_state):
|
||||
"""
|
||||
Set a new state for a user.
|
||||
:param chat_id:
|
||||
:param new_state: new_state of a user
|
||||
"""
|
||||
self.add_state(chat_id,new_state)
|
||||
|
||||
def add_data(self, chat_id, key, value):
|
||||
states_data = self.read_data()
|
||||
result = states_data[chat_id]['data'][key] = value
|
||||
self.save_data(result)
|
||||
return result
|
||||
|
||||
def finish(self, chat_id):
|
||||
"""
|
||||
Finish(delete) state of a user.
|
||||
:param chat_id:
|
||||
"""
|
||||
self.delete_state(chat_id)
|
||||
|
||||
def retrieve_data(self, chat_id):
|
||||
"""
|
||||
Save input text.
|
||||
|
||||
Usage:
|
||||
with bot.retrieve_data(message.chat.id) as data:
|
||||
data['name'] = message.text
|
||||
|
||||
Also, at the end of your 'Form' you can get the name:
|
||||
data['name']
|
||||
"""
|
||||
return StateFileContext(self, chat_id)
|
||||
class StatesGroup:
|
||||
def __init_subclass__(cls) -> None:
|
||||
# print all variables of a subclass
|
||||
for name, value in cls.__dict__.items():
|
||||
if not name.startswith('__') and not callable(value) and isinstance(value, State):
|
||||
# change value of that variable
|
||||
value.name = ':'.join((cls.__name__, name))
|
||||
|
||||
|
||||
class StateContext:
|
||||
"""
|
||||
Class for data.
|
||||
"""
|
||||
def __init__(self , obj: StateMemory, chat_id) -> None:
|
||||
self.obj = obj
|
||||
self.chat_id = chat_id
|
||||
self.data = obj.get_data(chat_id)
|
||||
|
||||
def __enter__(self):
|
||||
return self.data
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return
|
||||
|
||||
|
||||
class StateFileContext:
|
||||
"""
|
||||
Class for data.
|
||||
"""
|
||||
def __init__(self , obj: StateFile, chat_id) -> None:
|
||||
self.obj = obj
|
||||
self.chat_id = chat_id
|
||||
self.data = self.obj.get_data(self.chat_id)
|
||||
|
||||
def __enter__(self):
|
||||
return self.data
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
old_data = self.obj.read_data()
|
||||
for i in self.data:
|
||||
old_data[self.chat_id]['data'][i] = self.data.get(i)
|
||||
self.obj.save_data(old_data)
|
||||
return
|
||||
|
13
telebot/storage/__init__.py
Normal file
13
telebot/storage/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from telebot.storage.memory_storage import StateMemoryStorage
|
||||
from telebot.storage.redis_storage import StateRedisStorage
|
||||
from telebot.storage.pickle_storage import StatePickleStorage
|
||||
from telebot.storage.base_storage import StateContext,StateStorageBase
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
'StateStorageBase', 'StateContext',
|
||||
'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage'
|
||||
]
|
65
telebot/storage/base_storage.py
Normal file
65
telebot/storage/base_storage.py
Normal file
@ -0,0 +1,65 @@
|
||||
import copy
|
||||
|
||||
class StateStorageBase:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def set_data(self, chat_id, user_id, key, value):
|
||||
"""
|
||||
Set data for a user in a particular chat.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_data(self, chat_id, user_id):
|
||||
"""
|
||||
Get data for a user in a particular chat.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_state(self, chat_id, user_id, state):
|
||||
"""
|
||||
Set state for a particular user.
|
||||
|
||||
! Note that you should create a
|
||||
record if it does not exist, and
|
||||
if a record with state already exists,
|
||||
you need to update a record.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_state(self, chat_id, user_id):
|
||||
"""
|
||||
Delete state for a particular user.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_data(self, chat_id, user_id):
|
||||
"""
|
||||
Reset data for a particular user in a chat.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_state(self, chat_id, user_id):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(chat_id, user_id, data):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
class StateContext:
|
||||
"""
|
||||
Class for data.
|
||||
"""
|
||||
def __init__(self , obj, chat_id, user_id) -> None:
|
||||
self.obj = obj
|
||||
self.data = copy.deepcopy(obj.get_data(chat_id, user_id))
|
||||
self.chat_id = chat_id
|
||||
self.user_id = user_id
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
return self.data
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return self.obj.save(self.chat_id, self.user_id, self.data)
|
64
telebot/storage/memory_storage.py
Normal file
64
telebot/storage/memory_storage.py
Normal file
@ -0,0 +1,64 @@
|
||||
from telebot.storage.base_storage import StateStorageBase, StateContext
|
||||
|
||||
class StateMemoryStorage(StateStorageBase):
|
||||
def __init__(self) -> None:
|
||||
self.data = {}
|
||||
#
|
||||
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
|
||||
|
||||
|
||||
def set_state(self, chat_id, user_id, state):
|
||||
if chat_id in self.data:
|
||||
if user_id in self.data[chat_id]:
|
||||
self.data[chat_id][user_id]['state'] = state
|
||||
return True
|
||||
else:
|
||||
self.data[chat_id][user_id] = {'state': state, 'data': {}}
|
||||
return True
|
||||
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
|
||||
return True
|
||||
|
||||
def delete_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
del self.data[chat_id][user_id]
|
||||
if chat_id == user_id:
|
||||
del self.data[chat_id]
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['state']
|
||||
|
||||
return None
|
||||
def get_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['data']
|
||||
|
||||
return None
|
||||
|
||||
def reset_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'] = {}
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_data(self, chat_id, user_id, key, value):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'][key] = value
|
||||
return True
|
||||
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
|
||||
|
||||
def get_interactive_data(self, chat_id, user_id):
|
||||
return StateContext(self, chat_id, user_id)
|
||||
|
||||
def save(self, chat_id, user_id, data):
|
||||
self.data[chat_id][user_id]['data'] = data
|
112
telebot/storage/pickle_storage.py
Normal file
112
telebot/storage/pickle_storage.py
Normal file
@ -0,0 +1,112 @@
|
||||
from telebot.storage.base_storage import StateStorageBase, StateContext
|
||||
import os
|
||||
|
||||
|
||||
import pickle
|
||||
|
||||
|
||||
class StatePickleStorage(StateStorageBase):
|
||||
def __init__(self, file_path="./.state-save/states.pkl") -> None:
|
||||
self.file_path = file_path
|
||||
self.create_dir()
|
||||
self.data = self.read()
|
||||
|
||||
def convert_old_to_new(self):
|
||||
"""
|
||||
Use this function to convert old storage to new storage.
|
||||
This function is for people who was using pickle storage
|
||||
that was in version <=4.3.1.
|
||||
"""
|
||||
# old looks like:
|
||||
# {1: {'state': 'start', 'data': {'name': 'John'}}
|
||||
# we should update old version pickle to new.
|
||||
# new looks like:
|
||||
# {1: {2: {'state': 'start', 'data': {'name': 'John'}}}}
|
||||
new_data = {}
|
||||
for key, value in self.data.items():
|
||||
# this returns us id and dict with data and state
|
||||
new_data[key] = {key: value} # convert this to new
|
||||
# pass it to global data
|
||||
self.data = new_data
|
||||
self.update_data() # update data in file
|
||||
|
||||
def create_dir(self):
|
||||
"""
|
||||
Create directory .save-handlers.
|
||||
"""
|
||||
dirs = self.file_path.rsplit('/', maxsplit=1)[0]
|
||||
os.makedirs(dirs, exist_ok=True)
|
||||
if not os.path.isfile(self.file_path):
|
||||
with open(self.file_path,'wb') as file:
|
||||
pickle.dump({}, file)
|
||||
|
||||
def read(self):
|
||||
file = open(self.file_path, 'rb')
|
||||
data = pickle.load(file)
|
||||
file.close()
|
||||
return data
|
||||
|
||||
def update_data(self):
|
||||
file = open(self.file_path, 'wb+')
|
||||
pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
file.close()
|
||||
|
||||
def set_state(self, chat_id, user_id, state):
|
||||
if chat_id in self.data:
|
||||
if user_id in self.data[chat_id]:
|
||||
self.data[chat_id][user_id]['state'] = state
|
||||
return True
|
||||
else:
|
||||
self.data[chat_id][user_id] = {'state': state, 'data': {}}
|
||||
return True
|
||||
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
|
||||
self.update_data()
|
||||
return True
|
||||
|
||||
def delete_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
del self.data[chat_id][user_id]
|
||||
if chat_id == user_id:
|
||||
del self.data[chat_id]
|
||||
self.update_data()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_state(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['state']
|
||||
|
||||
return None
|
||||
def get_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
return self.data[chat_id][user_id]['data']
|
||||
|
||||
return None
|
||||
|
||||
def reset_data(self, chat_id, user_id):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'] = {}
|
||||
self.update_data()
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_data(self, chat_id, user_id, key, value):
|
||||
if self.data.get(chat_id):
|
||||
if self.data[chat_id].get(user_id):
|
||||
self.data[chat_id][user_id]['data'][key] = value
|
||||
self.update_data()
|
||||
return True
|
||||
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
|
||||
|
||||
def get_interactive_data(self, chat_id, user_id):
|
||||
return StateContext(self, chat_id, user_id)
|
||||
|
||||
def save(self, chat_id, user_id, data):
|
||||
self.data[chat_id][user_id]['data'] = data
|
||||
self.update_data()
|
176
telebot/storage/redis_storage.py
Normal file
176
telebot/storage/redis_storage.py
Normal file
@ -0,0 +1,176 @@
|
||||
from telebot.storage.base_storage import StateStorageBase, StateContext
|
||||
import json
|
||||
|
||||
redis_installed = True
|
||||
try:
|
||||
from redis import Redis, ConnectionPool
|
||||
|
||||
except:
|
||||
redis_installed = False
|
||||
|
||||
class StateRedisStorage(StateStorageBase):
|
||||
"""
|
||||
This class is for Redis storage.
|
||||
This will work only for states.
|
||||
To use it, just pass this class to:
|
||||
TeleBot(storage=StateRedisStorage())
|
||||
"""
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_'):
|
||||
self.redis = ConnectionPool(host=host, port=port, db=db, password=password)
|
||||
#self.con = Redis(connection_pool=self.redis) -> use this when necessary
|
||||
#
|
||||
# {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
|
||||
self.prefix = prefix
|
||||
if not redis_installed:
|
||||
raise Exception("Redis is not installed. Install it via 'pip install redis'")
|
||||
|
||||
def get_record(self, key):
|
||||
"""
|
||||
Function to get record from database.
|
||||
It has nothing to do with states.
|
||||
Made for backend compatibility
|
||||
"""
|
||||
connection = Redis(connection_pool=self.redis)
|
||||
result = connection.get(self.prefix+str(key))
|
||||
connection.close()
|
||||
if result: return json.loads(result)
|
||||
return
|
||||
|
||||
def set_record(self, key, value):
|
||||
"""
|
||||
Function to set record to database.
|
||||
It has nothing to do with states.
|
||||
Made for backend compatibility
|
||||
"""
|
||||
connection = Redis(connection_pool=self.redis)
|
||||
connection.set(self.prefix+str(key), json.dumps(value))
|
||||
connection.close()
|
||||
return True
|
||||
|
||||
def delete_record(self, key):
|
||||
"""
|
||||
Function to delete record from database.
|
||||
It has nothing to do with states.
|
||||
Made for backend compatibility
|
||||
"""
|
||||
connection = Redis(connection_pool=self.redis)
|
||||
connection.delete(self.prefix+str(key))
|
||||
connection.close()
|
||||
return True
|
||||
|
||||
def set_state(self, chat_id, user_id, state):
|
||||
"""
|
||||
Set state for a particular user in a chat.
|
||||
"""
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['state'] = state
|
||||
else:
|
||||
response[user_id] = {'state': state, 'data': {}}
|
||||
else:
|
||||
response = {user_id: {'state': state, 'data': {}}}
|
||||
self.set_record(chat_id, response)
|
||||
|
||||
return True
|
||||
|
||||
def delete_state(self, chat_id, user_id):
|
||||
"""
|
||||
Delete state for a particular user in a chat.
|
||||
"""
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
del response[user_id]
|
||||
if user_id == str(chat_id):
|
||||
self.delete_record(chat_id)
|
||||
return True
|
||||
else: self.set_record(chat_id, response)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_value(self, chat_id, user_id, key):
|
||||
"""
|
||||
Get value for a data of a user in a chat.
|
||||
"""
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
if key in response[user_id]['data']:
|
||||
return response[user_id]['data'][key]
|
||||
return None
|
||||
|
||||
|
||||
def get_state(self, chat_id, user_id):
|
||||
"""
|
||||
Get state of a user in a chat.
|
||||
"""
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
return response[user_id]['state']
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_data(self, chat_id, user_id):
|
||||
"""
|
||||
Get data of particular user in a particular chat.
|
||||
"""
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
return response[user_id]['data']
|
||||
return None
|
||||
|
||||
|
||||
def reset_data(self, chat_id, user_id):
|
||||
"""
|
||||
Reset data of a user in a chat.
|
||||
"""
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['data'] = {}
|
||||
self.set_record(chat_id, response)
|
||||
return True
|
||||
|
||||
|
||||
|
||||
|
||||
def set_data(self, chat_id, user_id, key, value):
|
||||
"""
|
||||
Set data without interactive data.
|
||||
"""
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['data'][key] = value
|
||||
self.set_record(chat_id, response)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_interactive_data(self, chat_id, user_id):
|
||||
"""
|
||||
Get Data in interactive way.
|
||||
You can use with() with this function.
|
||||
"""
|
||||
return StateContext(self, chat_id, user_id)
|
||||
|
||||
def save(self, chat_id, user_id, data):
|
||||
response = self.get_record(chat_id)
|
||||
user_id = str(user_id)
|
||||
if response:
|
||||
if user_id in response:
|
||||
response[user_id]['data'] = dict(data, **response[user_id]['data'])
|
||||
self.set_record(chat_id, response)
|
||||
return True
|
||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
Loading…
Reference in New Issue
Block a user