diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index 6c1fc4b..1b39761 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -1,5 +1,6 @@ from abc import ABC from typing import Optional, Union +from telebot.asyncio_handler_backends import State from telebot import types @@ -277,17 +278,29 @@ class StateFilter(AdvancedCustomFilter): async def check(self, message, text): if text == '*': return True + # needs to work with callbackquery + if isinstance(message, types.Message): + chat_id = message.chat.id + user_id = message.from_user.id + + if isinstance(message, types.CallbackQuery): + + chat_id = message.message.chat.id + user_id = message.from_user.id + message = message.message + + if isinstance(text, list): new_text = [] for i in text: - if isclass(i): i = i.name + if isinstance(i, State): i = i.name new_text.append(i) text = new_text - elif isinstance(text, object): + elif isinstance(text, State): text = text.name if message.chat.type == 'group': - group_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id) + group_state = await self.bot.current_states.get_state(user_id, chat_id) if group_state == text: return True elif group_state in text and type(text) is list: @@ -295,7 +308,7 @@ class StateFilter(AdvancedCustomFilter): else: - user_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id) + user_state = await self.bot.current_states.get_state(user_id, chat_id) if user_state == text: return True elif type(text) is list and user_state in text: diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index 0305673..8442be4 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -1,9 +1,12 @@ from abc import ABC from typing import Optional, Union +from telebot.handler_backends import State from telebot import types + + class SimpleCustomFilter(ABC): """ Simple Custom Filter base class. @@ -280,17 +283,32 @@ class StateFilter(AdvancedCustomFilter): def check(self, message, text): if text == '*': return True + + # needs to work with callbackquery + if isinstance(message, types.Message): + chat_id = message.chat.id + user_id = message.from_user.id + + if isinstance(message, types.CallbackQuery): + + chat_id = message.message.chat.id + user_id = message.from_user.id + message = message.message + + + if isinstance(text, list): new_text = [] for i in text: - if isclass(i): i = i.name + if isinstance(i, State): i = i.name new_text.append(i) text = new_text - elif isinstance(text, object): + elif isinstance(text, State): text = text.name + if message.chat.type == 'group': - group_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id) + group_state = self.bot.current_states.get_state(user_id, chat_id) if group_state == text: return True elif group_state in text and type(text) is list: @@ -298,7 +316,7 @@ class StateFilter(AdvancedCustomFilter): else: - user_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id) + user_state = self.bot.current_states.get_state(user_id, chat_id) if user_state == text: return True elif type(text) is list and user_state in text: