from abc import ABC
from typing import Optional, Union
from telebot.asyncio_handler_backends import State

from telebot import types


class SimpleCustomFilter(ABC):
    """
    Simple Custom Filter base class.
    Create child class with check() method.
    Accepts only message, returns bool value, that is compared with given in handler.
    """

    async def check(self, message):
        """
        Perform a check.
        """
        pass


class AdvancedCustomFilter(ABC):
    """
    Advanced Custom Filter base class.
    Create child class with check() method.
    Accepts two parameters, returns bool: True - filter passed, False - filter failed.
    message: Message class
    text: Filter value given in handler
    """

    async def check(self, message, text):
        """
        Perform a check.
        """
        pass


class TextFilter:
    """
    Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)

    The filter passes if ANY of the configured check modes matches.

    example of usage is in examples/custom_filters/advanced_text_filter.py
    """

    def __init__(self,
                 equals: Optional[str] = None,
                 contains: Optional[Union[str, list, tuple]] = None,
                 starts_with: Optional[Union[str, list, tuple]] = None,
                 ends_with: Optional[Union[str, list, tuple]] = None,
                 ignore_case: bool = False):
        """
        :param equals: string, True if object's text is equal to passed string
        :param contains: str, list[str] or tuple[str], True if any string element of iterable is in text
        :param starts_with: str, list[str] or tuple[str], True if object's text starts with any passed string
        :param ends_with: str, list[str] or tuple[str], True if object's text ends with any passed string
        :param ignore_case: bool (default False), case insensitive
        :raises ValueError: if no check mode is given, or a mode has an unsupported type
        """
        to_check = sum(pattern is not None for pattern in (equals, contains, starts_with, ends_with))
        if to_check == 0:
            raise ValueError('None of the check modes was specified')

        self.equals = equals
        self.contains = self._check_iterable(contains, filter_name='contains')
        self.starts_with = self._check_iterable(starts_with, filter_name='starts_with')
        self.ends_with = self._check_iterable(ends_with, filter_name='ends_with')
        self.ignore_case = ignore_case

    def _check_iterable(self, iterable, filter_name):
        """
        Normalize a check-mode value to a list of strings.

        :param iterable: falsy value (returned unchanged), a str (wrapped in a list),
            or a list/tuple (non-str items are silently dropped)
        :param filter_name: parameter name, used in the error message
        :return: the normalized value
        :raises ValueError: if the value is of any other type
        """
        if not iterable:
            return iterable
        if isinstance(iterable, str):
            return [iterable]
        if isinstance(iterable, (list, tuple)):
            return [i for i in iterable if isinstance(i, str)]
        raise ValueError(f"Incorrect value of {filter_name!r}")

    async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
        """
        :param obj: Poll (question), Message (text or caption), CallbackQuery (data)
            or InlineQuery (query)
        :return: True if any configured check mode matches, otherwise False
        """
        if isinstance(obj, types.Poll):
            text = obj.question
        elif isinstance(obj, types.Message):
            text = obj.text or obj.caption
        elif isinstance(obj, types.CallbackQuery):
            text = obj.data
        elif isinstance(obj, types.InlineQuery):
            text = obj.query
        else:
            return False

        # e.g. a media message without caption or a callback query without data:
        # there is nothing to compare, and str methods below would fail on None.
        if text is None:
            return False

        if self.ignore_case:
            text = text.lower()

        def prepare(pattern):
            # Patterns are normalized the same way as the text being checked.
            pattern = str(pattern)
            return pattern.lower() if self.ignore_case else pattern

        # `is not None` so that equals='' can match an empty text (e.g. empty inline query).
        if self.equals is not None and prepare(self.equals) == text:
            return True
        if self.contains and any(prepare(i) in text for i in self.contains):
            return True
        if self.starts_with and any(text.startswith(prepare(i)) for i in self.starts_with):
            return True
        if self.ends_with and any(text.endswith(prepare(i)) for i in self.ends_with):
            return True
        return False


class TextMatchFilter(AdvancedCustomFilter):
    """
    Filter to check Text message.
    key: text

    Example:
    @bot.message_handler(text=['account'])
    """

    key = 'text'

    async def check(self, message, text):
        """
        :param message: object to check
        :param text: a TextFilter (delegated to), a list/tuple of accepted texts,
            or a single text compared for equality with message.text
        """
        if isinstance(text, TextFilter):
            return await text.check(message)
        if isinstance(text, (list, tuple)):
            return message.text in text
        return text == message.text


class TextContainsFilter(AdvancedCustomFilter):
    """
    Filter to check Text message.
    key: text_contains

    Example:
    # Will respond if any message.text contains word 'account'
    @bot.message_handler(text_contains=['account'])
    """

    key = 'text_contains'

    async def check(self, message, text):
        """
        :param message: message whose text is checked
        :param text: str, or list/tuple of str (non-str items are ignored)
        :return: True if any given string occurs in message.text
        :raises ValueError: if text is not a str, list or tuple
        """
        if not isinstance(text, (str, list, tuple)):
            raise ValueError("Incorrect text_contains value")
        if isinstance(text, str):
            text = [text]
        else:
            text = [i for i in text if isinstance(i, str)]

        # Messages without text (e.g. media) cannot contain anything.
        if message.text is None:
            return False
        return any(i in message.text for i in text)


class TextStartsFilter(AdvancedCustomFilter):
    """
    Filter to check whether message starts with some text.

    Example:
    # Will work if message.text starts with 'Sir'.
    @bot.message_handler(text_startswith='Sir')
    """

    key = 'text_startswith'

    async def check(self, message, text):
        """
        :param message: message whose text is checked
        :param text: prefix string, or a list/tuple of prefixes (any may match)
        """
        if message.text is None:
            return False
        # str.startswith accepts a tuple of prefixes but not a list.
        if isinstance(text, list):
            text = tuple(text)
        return message.text.startswith(text)


class ChatFilter(AdvancedCustomFilter):
    """
    Check whether chat_id corresponds to given chat_id.

    Example:
    @bot.message_handler(chat_id=[99999])
    """

    key = 'chat_id'

    async def check(self, message, text):
        """
        :param text: collection of allowed chat ids
        :return: True if message.chat.id is in the collection
        """
        return message.chat.id in text


class ForwardFilter(SimpleCustomFilter):
    """
    Check whether message was forwarded from channel or group
    (i.e. message.forward_from_chat is set).

    Example:

    @bot.message_handler(is_forwarded=True)
    """

    key = 'is_forwarded'

    async def check(self, message):
        return message.forward_from_chat is not None


class IsReplyFilter(SimpleCustomFilter):
    """
    Check whether message is a reply.

    Example:

    @bot.message_handler(is_reply=True)
    """

    key = 'is_reply'

    async def check(self, message):
        return message.reply_to_message is not None


class LanguageFilter(AdvancedCustomFilter):
    """
    Check users language_code.

    Example:

    @bot.message_handler(language_code=['ru'])
    """

    key = 'language_code'

    async def check(self, message, text):
        """
        :param text: a language code, or a list/tuple of accepted codes
        :return: True if the sender's language_code matches; False when the
            message has no sender (from_user is None)
        """
        if message.from_user is None:
            return False
        language_code = message.from_user.language_code
        if isinstance(text, (list, tuple)):
            return language_code in text
        return language_code == text


class IsAdminFilter(SimpleCustomFilter):
    """
    Check whether the user is administrator / owner of the chat.

    Example:
    @bot.message_handler(chat_types=['supergroup'], is_chat_admin=True)
    """

    key = 'is_chat_admin'

    def __init__(self, bot):
        self._bot = bot

    async def check(self, message):
        """
        Query the sender's membership via bot.get_chat_member and pass
        if its status is 'creator' or 'administrator'.
        """
        result = await self._bot.get_chat_member(message.chat.id, message.from_user.id)
        return result.status in ['creator', 'administrator']


class StateFilter(AdvancedCustomFilter):
    """
    Filter to check state.

    Example:
    @bot.message_handler(state=1)
    """

    key = 'state'

    def __init__(self, bot):
        self.bot = bot

    async def check(self, message, text):
        """
        :param message: Message or CallbackQuery
        :param text: '*' (any state), a State, a state value, or a list/tuple of those
        :return: True if the current state stored in bot.current_states for
            (user, chat) matches, otherwise False
        """
        if text == '*':
            return True

        # needs to work with callbackquery as well as messages
        if isinstance(message, types.Message):
            chat_id = message.chat.id
        elif isinstance(message, types.CallbackQuery):
            # Callbacks from inline-mode messages carry no message, hence no chat.
            if message.message is None:
                return False
            chat_id = message.message.chat.id
        else:
            return False
        # For a CallbackQuery this is the user who pressed the button.
        user_id = message.from_user.id

        if isinstance(text, (list, tuple)):
            text = [i.name if isinstance(i, State) else i for i in text]
        elif isinstance(text, State):
            text = text.name

        current_state = await self.bot.current_states.get_state(user_id, chat_id)
        if current_state == text:
            return True
        # Check the type first: `None in 'name'` / `x in 1` would raise TypeError.
        return isinstance(text, list) and current_state in text


class IsDigitFilter(SimpleCustomFilter):
    """
    Filter to check whether the string is made up of only digits.

    Example:
    @bot.message_handler(is_digit=True)
    """

    key = 'is_digit'

    async def check(self, message):
        # Messages without text (e.g. media) are not digits.
        return message.text is not None and message.text.isdigit()