diff --git a/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/filters.py b/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/filters.py new file mode 100644 index 0000000..7c5c304 --- /dev/null +++ b/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/filters.py @@ -0,0 +1,26 @@ +from telebot import types +from telebot.async_telebot import AsyncTeleBot +from telebot.asyncio_filters import AdvancedCustomFilter +from telebot.callback_data import CallbackData, CallbackDataFilter + +calendar_factory = CallbackData("year", "month", prefix="calendar") +calendar_zoom = CallbackData("year", prefix="calendar_zoom") + + +class CalendarCallbackFilter(AdvancedCustomFilter): + key = 'calendar_config' + + async def check(self, call: types.CallbackQuery, config: CallbackDataFilter): + return config.check(query=call) + + +class CalendarZoomCallbackFilter(AdvancedCustomFilter): + key = 'calendar_zoom_config' + + async def check(self, call: types.CallbackQuery, config: CallbackDataFilter): + return config.check(query=call) + + +def bind_filters(bot: AsyncTeleBot): + bot.add_custom_filter(CalendarCallbackFilter()) + bot.add_custom_filter(CalendarZoomCallbackFilter()) diff --git a/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/keyboards.py b/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/keyboards.py new file mode 100644 index 0000000..1aee88c --- /dev/null +++ b/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/keyboards.py @@ -0,0 +1,92 @@ +import calendar +from datetime import date, timedelta + +from filters import calendar_factory, calendar_zoom +from telebot.types import InlineKeyboardMarkup, InlineKeyboardButton + +EMTPY_FIELD = '1' +WEEK_DAYS = [calendar.day_abbr[i] for i in range(7)] +MONTHS = [(i, calendar.month_name[i]) for i in range(1, 13)] + + +def generate_calendar_days(year: int, month: int): + keyboard = InlineKeyboardMarkup(row_width=7) + today = date.today() + + keyboard.add( + InlineKeyboardButton( + text=date(year=year, month=month, day=1).strftime('%b %Y'), + callback_data=EMTPY_FIELD + ) + ) + keyboard.add(*[ + InlineKeyboardButton( + text=day, + callback_data=EMTPY_FIELD + ) + for day in WEEK_DAYS + ]) + + for week in calendar.Calendar().monthdayscalendar(year=year, month=month): + week_buttons = [] + for day in week: + day_name = ' ' + if day == today.day and today.year == year and today.month == month: + day_name = '🔘' + elif day != 0: + day_name = str(day) + week_buttons.append( + InlineKeyboardButton( + text=day_name, + callback_data=EMTPY_FIELD + ) + ) + keyboard.add(*week_buttons) + + previous_date = date(year=year, month=month, day=1) - timedelta(days=1) + next_date = date(year=year, month=month, day=1) + timedelta(days=31) + + keyboard.add( + InlineKeyboardButton( + text='Previous month', + callback_data=calendar_factory.new(year=previous_date.year, month=previous_date.month) + ), + InlineKeyboardButton( + text='Zoom out', + callback_data=calendar_zoom.new(year=year) + ), + InlineKeyboardButton( + text='Next month', + callback_data=calendar_factory.new(year=next_date.year, month=next_date.month) + ), + ) + + return keyboard + + +def generate_calendar_months(year: int): + keyboard = InlineKeyboardMarkup(row_width=3) + keyboard.add( + InlineKeyboardButton( + text=date(year=year, month=1, day=1).strftime('Year %Y'), + callback_data=EMTPY_FIELD + ) + ) + keyboard.add(*[ + InlineKeyboardButton( + text=month, + callback_data=calendar_factory.new(year=year, month=month_number) + ) + for month_number, month in MONTHS + ]) + keyboard.add( + InlineKeyboardButton( + text='Previous year', + callback_data=calendar_zoom.new(year=year - 1) + ), + InlineKeyboardButton( + text='Next year', + callback_data=calendar_zoom.new(year=year + 1) + ) + ) + return keyboard diff --git a/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/main.py b/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/main.py new file mode 100644 index 0000000..0041474 --- /dev/null +++ b/examples/asynchronous_telebot/callback_data_examples/advanced_calendar_example/main.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +""" +This Example will show you an advanced usage of CallbackData. +In this example calendar was implemented +""" +import asyncio +from datetime import date + +from filters import calendar_factory, calendar_zoom, bind_filters +from keyboards import generate_calendar_days, generate_calendar_months, EMTPY_FIELD +from telebot import types +from telebot.async_telebot import AsyncTeleBot + +API_TOKEN = '' +bot = AsyncTeleBot(API_TOKEN) + + +@bot.message_handler(commands='start') +async def start_command_handler(message: types.Message): + await bot.send_message(message.chat.id, + f"Hello {message.from_user.first_name}. This bot is an example of calendar keyboard." + "\nPress /calendar to see it.") + + +@bot.message_handler(commands='calendar') +async def calendar_command_handler(message: types.Message): + now = date.today() + await bot.send_message(message.chat.id, 'Calendar', + reply_markup=generate_calendar_days(year=now.year, month=now.month)) + + +@bot.callback_query_handler(func=None, calendar_config=calendar_factory.filter()) +async def calendar_action_handler(call: types.CallbackQuery): + callback_data: dict = calendar_factory.parse(callback_data=call.data) + year, month = int(callback_data['year']), int(callback_data['month']) + + await bot.edit_message_reply_markup(call.message.chat.id, call.message.id, + reply_markup=generate_calendar_days(year=year, month=month)) + + +@bot.callback_query_handler(func=None, calendar_zoom_config=calendar_zoom.filter()) +async def calendar_zoom_out_handler(call: types.CallbackQuery): + callback_data: dict = calendar_zoom.parse(callback_data=call.data) + year = int(callback_data.get('year')) + + await bot.edit_message_reply_markup(call.message.chat.id, call.message.id, + reply_markup=generate_calendar_months(year=year)) + + +@bot.callback_query_handler(func=lambda call: call.data == EMTPY_FIELD) +async def callback_empty_field_handler(call: types.CallbackQuery): + await bot.answer_callback_query(call.id) + + +if __name__ == '__main__': + bind_filters(bot) + asyncio.run(bot.infinity_polling()) diff --git a/examples/asynchronous_telebot/CallbackData_example.py b/examples/asynchronous_telebot/callback_data_examples/simple_products_example.py similarity index 100% rename from examples/asynchronous_telebot/CallbackData_example.py rename to examples/asynchronous_telebot/callback_data_examples/simple_products_example.py diff --git a/examples/asynchronous_telebot/custom_filters/advanced_text_filter.py b/examples/asynchronous_telebot/custom_filters/advanced_text_filter.py new file mode 100644 index 0000000..1200a6c --- /dev/null +++ b/examples/asynchronous_telebot/custom_filters/advanced_text_filter.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +This Example will show you usage of TextFilter +In this example you will see how to use TextFilter +with (message_handler, callback_query_handler, poll_handler) +""" +import asyncio + +from telebot import types +from telebot.async_telebot import AsyncTeleBot +from telebot.asyncio_filters import TextMatchFilter, TextFilter, IsReplyFilter + +bot = AsyncTeleBot("") + + +@bot.message_handler(text=TextFilter(equals='hello')) +async def hello_handler(message: types.Message): + await bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(equals='hello', ignore_case=True)) +async def hello_handler_ignore_case(message: types.Message): + await bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(contains=['good', 'bad'])) +async def contains_handler(message: types.Message): + await bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(contains=['good', 'bad'], ignore_case=True)) +async def contains_handler_ignore_case(message: types.Message): + await bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(starts_with='st')) # stArk, steve, stONE +async def starts_with_handler(message: types.Message): + await bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(starts_with='st', ignore_case=True)) # STark, sTeve, stONE +async def starts_with_handler_ignore_case(message: types.Message): + await bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(ends_with='ay')) # wednesday, SUNday, WeekDay +async def ends_with_handler(message: types.Message): + await bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(ends_with='ay', ignore_case=True)) # wednesdAY, sundAy, WeekdaY +async def ends_with_handler_ignore_case(message: types.Message): + await bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(equals='/callback')) +async def send_callback(message: types.Message): + keyboard = types.InlineKeyboardMarkup( + keyboard=[ + [types.InlineKeyboardButton(text='callback data', callback_data='example')], + [types.InlineKeyboardButton(text='ignore case callback data', callback_data='ExAmPLe')] + ] + ) + await bot.send_message(message.chat.id, message.text, reply_markup=keyboard) + + +@bot.callback_query_handler(func=None, text=TextFilter(equals='example')) +async def callback_query_handler(call: types.CallbackQuery): + await bot.answer_callback_query(call.id, call.data, show_alert=True) + + +@bot.callback_query_handler(func=None, text=TextFilter(equals='example', ignore_case=True)) +async def callback_query_handler_ignore_case(call: types.CallbackQuery): + await bot.answer_callback_query(call.id, call.data + " ignore case", show_alert=True) + + +@bot.message_handler(text=TextFilter(equals='/poll')) +async def send_poll(message: types.Message): + await bot.send_poll(message.chat.id, question='When do you prefer to work?', options=['Morning', 'Night']) + await bot.send_poll(message.chat.id, question='WHEN DO you pRefeR to worK?', options=['Morning', 'Night']) + + +@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?')) +async def poll_question_handler(poll: types.Poll): + print(poll.question) + + +@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?', ignore_case=True)) +async def poll_question_handler_ignore_case(poll: types.Poll): + print(poll.question + ' ignore case') + + +# either hi or contains one of (привет, salom) +@bot.message_handler(text=TextFilter(equals="hi", contains=('привет', 'salom'), ignore_case=True)) +async def multiple_patterns_handler(message: types.Message): + await bot.send_message(message.chat.id, message.text) + + +# starts with one of (mi, mea) for ex. minor, milk, meal, meat +@bot.message_handler(text=TextFilter(starts_with=['mi', 'mea'], ignore_case=True)) +async def multiple_starts_with_handler(message: types.Message): + await bot.send_message(message.chat.id, message.text) + + +# ends with one of (es, on) for ex. Jones, Davies, Johnson, Wilson +@bot.message_handler(text=TextFilter(ends_with=['es', 'on'], ignore_case=True)) +async def multiple_ends_with_handler(message: types.Message): + await bot.send_message(message.chat.id, message.text) + + +# !ban /ban .ban !бан /бан .бан +@bot.message_handler(is_reply=True, + text=TextFilter(starts_with=('!', '/', '.'), ends_with=['ban', 'бан'], ignore_case=True)) +async def ban_command_handler(message: types.Message): + if len(message.text) == 4 and message.chat.type != 'private': + try: + await bot.ban_chat_member(message.chat.id, message.reply_to_message.from_user.id) + await bot.reply_to(message.reply_to_message, 'Banned.') + except Exception as err: + print(err.args) + return + + +if __name__ == '__main__': + bot.add_custom_filter(TextMatchFilter()) + bot.add_custom_filter(IsReplyFilter()) + asyncio.run(bot.polling()) diff --git a/examples/custom_filters/advanced_text_filter.py b/examples/custom_filters/advanced_text_filter.py new file mode 100644 index 0000000..2b01685 --- /dev/null +++ b/examples/custom_filters/advanced_text_filter.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +""" +This Example will show you usage of TextFilter +In this example you will see how to use TextFilter +with (message_handler, callback_query_handler, poll_handler) +""" + +from telebot import TeleBot, types +from telebot.custom_filters import TextFilter, TextMatchFilter, IsReplyFilter + +bot = TeleBot("") + + +@bot.message_handler(text=TextFilter(equals='hello')) +def hello_handler(message: types.Message): + bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(equals='hello', ignore_case=True)) +def hello_handler_ignore_case(message: types.Message): + bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(contains=['good', 'bad'])) +def contains_handler(message: types.Message): + bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(contains=['good', 'bad'], ignore_case=True)) +def contains_handler_ignore_case(message: types.Message): + bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(starts_with='st')) # stArk, steve, stONE +def starts_with_handler(message: types.Message): + bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(starts_with='st', ignore_case=True)) # STark, sTeve, stONE +def starts_with_handler_ignore_case(message: types.Message): + bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(ends_with='ay')) # wednesday, SUNday, WeekDay +def ends_with_handler(message: types.Message): + bot.send_message(message.chat.id, message.text) + + +@bot.message_handler(text=TextFilter(ends_with='ay', ignore_case=True)) # wednesdAY, sundAy, WeekdaY +def ends_with_handler_ignore_case(message: types.Message): + bot.send_message(message.chat.id, message.text + ' ignore case') + + +@bot.message_handler(text=TextFilter(equals='/callback')) +def send_callback(message: types.Message): + keyboard = types.InlineKeyboardMarkup( + keyboard=[ + [types.InlineKeyboardButton(text='callback data', callback_data='example')], + [types.InlineKeyboardButton(text='ignore case callback data', callback_data='ExAmPLe')] + ] + ) + bot.send_message(message.chat.id, message.text, reply_markup=keyboard) + + +@bot.callback_query_handler(func=None, text=TextFilter(equals='example')) +def callback_query_handler(call: types.CallbackQuery): + bot.answer_callback_query(call.id, call.data, show_alert=True) + + +@bot.callback_query_handler(func=None, text=TextFilter(equals='example', ignore_case=True)) +def callback_query_handler_ignore_case(call: types.CallbackQuery): + bot.answer_callback_query(call.id, call.data + " ignore case", show_alert=True) + + +@bot.message_handler(text=TextFilter(equals='/poll')) +def send_poll(message: types.Message): + bot.send_poll(message.chat.id, question='When do you prefer to work?', options=['Morning', 'Night']) + bot.send_poll(message.chat.id, question='WHEN DO you pRefeR to worK?', options=['Morning', 'Night']) + + +@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?')) +def poll_question_handler(poll: types.Poll): + print(poll.question) + + +@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?', ignore_case=True)) +def poll_question_handler_ignore_case(poll: types.Poll): + print(poll.question + ' ignore case') + + +# either hi or contains one of (привет, salom) +@bot.message_handler(text=TextFilter(equals="hi", contains=('привет', 'salom'), ignore_case=True)) +def multiple_patterns_handler(message: types.Message): + bot.send_message(message.chat.id, message.text) + + +# starts with one of (mi, mea) for ex. minor, milk, meal, meat +@bot.message_handler(text=TextFilter(starts_with=['mi', 'mea'], ignore_case=True)) +def multiple_starts_with_handler(message: types.Message): + bot.send_message(message.chat.id, message.text) + + +# ends with one of (es, on) for ex. Jones, Davies, Johnson, Wilson +@bot.message_handler(text=TextFilter(ends_with=['es', 'on'], ignore_case=True)) +def multiple_ends_with_handler(message: types.Message): + bot.send_message(message.chat.id, message.text) + + +# !ban /ban .ban !бан /бан .бан +@bot.message_handler(is_reply=True, + text=TextFilter(starts_with=('!', '/', '.'), ends_with=['ban', 'бан'], ignore_case=True)) +def ban_command_handler(message: types.Message): + if len(message.text) == 4 and message.chat.type != 'private': + try: + bot.ban_chat_member(message.chat.id, message.reply_to_message.from_user.id) + bot.reply_to(message.reply_to_message, 'Banned.') + except Exception as err: + print(err.args) + return + + +if __name__ == '__main__': + bot.add_custom_filter(TextMatchFilter()) + bot.add_custom_filter(IsReplyFilter()) + bot.infinity_polling() diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index 417b110..cb0120b 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -1,4 +1,8 @@ from abc import ABC +from typing import Optional, Union + +from telebot import types + class SimpleCustomFilter(ABC): """ @@ -30,6 +34,101 @@ class AdvancedCustomFilter(ABC): pass +class TextFilter: + """ + Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll) + + example of usage is in examples/custom_filters/advanced_text_filter.py + """ + + def __init__(self, + equals: Optional[str] = None, + contains: Optional[Union[list, tuple]] = None, + starts_with: Optional[Union[str, list, tuple]] = None, + ends_with: Optional[Union[str, list, tuple]] = None, + ignore_case: bool = False): + + """ + :param equals: string, True if object's text is equal to passed string + :param contains: list[str] or tuple[str], True if any string element of iterable is in text + :param starts_with: string, True if object's text starts with passed string + :param ends_with: string, True if object's text starts with passed string + :param ignore_case: bool (default False), case insensitive + """ + + to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with))) + if to_check == 0: + raise ValueError('None of the check modes was specified') + + self.equals = equals + self.contains = self._check_iterable(contains, filter_name='contains') + self.starts_with = self._check_iterable(starts_with, filter_name='starts_with') + self.ends_with = self._check_iterable(ends_with, filter_name='ends_with') + self.ignore_case = ignore_case + + def _check_iterable(self, iterable, filter_name): + if not iterable: + pass + elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple): + raise ValueError(f"Incorrect value of {filter_name!r}") + elif isinstance(iterable, str): + iterable = [iterable] + elif isinstance(iterable, list) or isinstance(iterable, tuple): + iterable = [i for i in iterable if isinstance(i, str)] + return iterable + + async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]): + + if isinstance(obj, types.Poll): + text = obj.question + elif isinstance(obj, types.Message): + text = obj.text or obj.caption + elif isinstance(obj, types.CallbackQuery): + text = obj.data + elif isinstance(obj, types.InlineQuery): + text = obj.query + else: + return False + + if self.ignore_case: + text = text.lower() + + if self.equals: + self.equals = self.equals.lower() + elif self.contains: + self.contains = tuple(map(str.lower, self.contains)) + elif self.starts_with: + self.starts_with = tuple(map(str.lower, self.starts_with)) + elif self.ends_with: + self.ends_with = tuple(map(str.lower, self.ends_with)) + + if self.equals: + result = self.equals == text + if result: + return True + elif not result and not any((self.contains, self.starts_with, self.ends_with)): + return False + + if self.contains: + result = any([i in text for i in self.contains]) + if result: + return True + elif not result and not any((self.starts_with, self.ends_with)): + return False + + if self.starts_with: + result = any([text.startswith(i) for i in self.starts_with]) + if result: + return True + elif not result and not self.ends_with: + return False + + if self.ends_with: + return any([text.endswith(i) for i in self.ends_with]) + + return False + + class TextMatchFilter(AdvancedCustomFilter): """ Filter to check Text message. @@ -42,8 +141,13 @@ class TextMatchFilter(AdvancedCustomFilter): key = 'text' async def check(self, message, text): - if type(text) is list:return message.text in text - else: return text == message.text + if isinstance(text, TextFilter): + return await text.check(message) + elif type(text) is list: + return message.text in text + else: + return text == message.text + class TextContainsFilter(AdvancedCustomFilter): """ @@ -58,7 +162,15 @@ class TextContainsFilter(AdvancedCustomFilter): key = 'text_contains' async def check(self, message, text): - return text in message.text + if not isinstance(text, str) and not isinstance(text, list) and not isinstance(text, tuple): + raise ValueError("Incorrect text_contains value") + elif isinstance(text, str): + text = [text] + elif isinstance(text, list) or isinstance(text, tuple): + text = [i for i in text if isinstance(i, str)] + + return any([i in message.text for i in text]) + class TextStartsFilter(AdvancedCustomFilter): """ @@ -70,8 +182,10 @@ class TextStartsFilter(AdvancedCustomFilter): """ key = 'text_startswith' + async def check(self, message, text): - return message.text.startswith(text) + return message.text.startswith(text) + class ChatFilter(AdvancedCustomFilter): """ @@ -82,9 +196,11 @@ class ChatFilter(AdvancedCustomFilter): """ key = 'chat_id' + async def check(self, message, text): return message.chat.id in text + class ForwardFilter(SimpleCustomFilter): """ Check whether message was forwarded from channel or group. @@ -99,6 +215,7 @@ class ForwardFilter(SimpleCustomFilter): async def check(self, message): return message.forward_from_chat is not None + class IsReplyFilter(SimpleCustomFilter): """ Check whether message is a reply. @@ -114,7 +231,6 @@ class IsReplyFilter(SimpleCustomFilter): return message.reply_to_message is not None - class LanguageFilter(AdvancedCustomFilter): """ Check users language_code. @@ -127,8 +243,11 @@ class LanguageFilter(AdvancedCustomFilter): key = 'language_code' async def check(self, message, text): - if type(text) is list:return message.from_user.language_code in text - else: return message.from_user.language_code == text + if type(text) is list: + return message.from_user.language_code in text + else: + return message.from_user.language_code == text + class IsAdminFilter(SimpleCustomFilter): """ @@ -147,6 +266,7 @@ class IsAdminFilter(SimpleCustomFilter): result = await self._bot.get_chat_member(message.chat.id, message.from_user.id) return result.status in ['creator', 'administrator'] + class StateFilter(AdvancedCustomFilter): """ Filter to check state. @@ -154,8 +274,10 @@ class StateFilter(AdvancedCustomFilter): Example: @bot.message_handler(state=1) """ + def __init__(self, bot): self.bot = bot + key = 'state' async def check(self, message, text): @@ -166,22 +288,23 @@ class StateFilter(AdvancedCustomFilter): text = new_text elif isinstance(text, object): text = text.name - + if message.chat.type == 'group': group_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id) if group_state == text: return True elif group_state in text and type(text) is list: return True - - + + else: - user_state = await self.bot.current_states.get_state(message.chat.id,message.from_user.id) + user_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id) if user_state == text: return True elif type(text) is list and user_state in text: return True + class IsDigitFilter(SimpleCustomFilter): """ Filter to check whether the string is made up of only digits. diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index d95ecd3..e6a1531 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -1,4 +1,9 @@ from abc import ABC +from typing import Optional, Union + +from telebot import types + + class SimpleCustomFilter(ABC): """ Simple Custom Filter base class. @@ -29,6 +34,100 @@ class AdvancedCustomFilter(ABC): pass +class TextFilter: + """ + Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll) + + example of usage is in examples/custom_filters/advanced_text_filter.py + """ + + def __init__(self, + equals: Optional[str] = None, + contains: Optional[Union[list, tuple]] = None, + starts_with: Optional[Union[str, list, tuple]] = None, + ends_with: Optional[Union[str, list, tuple]] = None, + ignore_case: bool = False): + + """ + :param equals: string, True if object's text is equal to passed string + :param contains: list[str] or tuple[str], True if any string element of iterable is in text + :param starts_with: string, True if object's text starts with passed string + :param ends_with: string, True if object's text starts with passed string + :param ignore_case: bool (default False), case insensitive + """ + + to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with))) + if to_check == 0: + raise ValueError('None of the check modes was specified') + + self.equals = equals + self.contains = self._check_iterable(contains, filter_name='contains') + self.starts_with = self._check_iterable(starts_with, filter_name='starts_with') + self.ends_with = self._check_iterable(ends_with, filter_name='ends_with') + self.ignore_case = ignore_case + + def _check_iterable(self, iterable, filter_name: str): + if not iterable: + pass + elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple): + raise ValueError(f"Incorrect value of {filter_name!r}") + elif isinstance(iterable, str): + iterable = [iterable] + elif isinstance(iterable, list) or isinstance(iterable, tuple): + iterable = [i for i in iterable if isinstance(i, str)] + return iterable + + def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]): + + if isinstance(obj, types.Poll): + text = obj.question + elif isinstance(obj, types.Message): + text = obj.text or obj.caption + elif isinstance(obj, types.CallbackQuery): + text = obj.data + elif isinstance(obj, types.InlineQuery): + text = obj.query + else: + return False + + if self.ignore_case: + text = text.lower() + + if self.equals: + self.equals = self.equals.lower() + elif self.contains: + self.contains = tuple(map(str.lower, self.contains)) + elif self.starts_with: + self.starts_with = tuple(map(str.lower, self.starts_with)) + elif self.ends_with: + self.ends_with = tuple(map(str.lower, self.ends_with)) + + if self.equals: + result = self.equals == text + if result: + return True + elif not result and not any((self.contains, self.starts_with, self.ends_with)): + return False + + if self.contains: + result = any([i in text for i in self.contains]) + if result: + return True + elif not result and not any((self.starts_with, self.ends_with)): + return False + + if self.starts_with: + result = any([text.startswith(i) for i in self.starts_with]) + if result: + return True + elif not result and not self.ends_with: + return False + + if self.ends_with: + return any([text.endswith(i) for i in self.ends_with]) + + return False + class TextMatchFilter(AdvancedCustomFilter): """ Filter to check Text message. @@ -41,8 +140,13 @@ class TextMatchFilter(AdvancedCustomFilter): key = 'text' def check(self, message, text): - if type(text) is list:return message.text in text - else: return text == message.text + if isinstance(text, TextFilter): + return text.check(message) + elif type(text) is list: + return message.text in text + else: + return text == message.text + class TextContainsFilter(AdvancedCustomFilter): """ @@ -57,7 +161,15 @@ class TextContainsFilter(AdvancedCustomFilter): key = 'text_contains' def check(self, message, text): - return text in message.text + if not isinstance(text, str) and not isinstance(text, list) and not isinstance(text, tuple): + raise ValueError("Incorrect text_contains value") + elif isinstance(text, str): + text = [text] + elif isinstance(text, list) or isinstance(text, tuple): + text = [i for i in text if isinstance(i, str)] + + return any([i in message.text for i in text]) + class TextStartsFilter(AdvancedCustomFilter): """ @@ -69,8 +181,10 @@ class TextStartsFilter(AdvancedCustomFilter): """ key = 'text_startswith' + def check(self, message, text): - return message.text.startswith(text) + return message.text.startswith(text) + class ChatFilter(AdvancedCustomFilter): """ @@ -81,9 +195,11 @@ class ChatFilter(AdvancedCustomFilter): """ key = 'chat_id' + def check(self, message, text): return message.chat.id in text + class ForwardFilter(SimpleCustomFilter): """ Check whether message was forwarded from channel or group. @@ -98,6 +214,7 @@ class ForwardFilter(SimpleCustomFilter): def check(self, message): return message.forward_from_chat is not None + class IsReplyFilter(SimpleCustomFilter): """ Check whether message is a reply. @@ -113,7 +230,6 @@ class IsReplyFilter(SimpleCustomFilter): return message.reply_to_message is not None - class LanguageFilter(AdvancedCustomFilter): """ Check users language_code. @@ -126,8 +242,11 @@ class LanguageFilter(AdvancedCustomFilter): key = 'language_code' def check(self, message, text): - if type(text) is list:return message.from_user.language_code in text - else: return message.from_user.language_code == text + if type(text) is list: + return message.from_user.language_code in text + else: + return message.from_user.language_code == text + class IsAdminFilter(SimpleCustomFilter): """ @@ -145,6 +264,7 @@ class IsAdminFilter(SimpleCustomFilter): def check(self, message): return self._bot.get_chat_member(message.chat.id, message.from_user.id).status in ['creator', 'administrator'] + class StateFilter(AdvancedCustomFilter): """ Filter to check state. @@ -152,8 +272,10 @@ class StateFilter(AdvancedCustomFilter): Example: @bot.message_handler(state=1) """ + def __init__(self, bot): self.bot = bot + key = 'state' def check(self, message, text): @@ -170,14 +292,16 @@ class StateFilter(AdvancedCustomFilter): return True elif group_state in text and type(text) is list: return True - - + + else: - user_state = self.bot.current_states.get_state(message.chat.id,message.from_user.id) + user_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id) if user_state == text: return True elif type(text) is list and user_state in text: return True + + class IsDigitFilter(SimpleCustomFilter): """ Filter to check whether the string is made up of only digits.