From 5f7ccc8c9b5d14b99a7e2669830c3bf242380ae3 Mon Sep 17 00:00:00 2001 From: abdullaev388 Date: Sat, 12 Feb 2022 17:33:29 +0500 Subject: [PATCH] created async TextFilter --- .../custom_filters/advanced_text_filter.py | 5 +- telebot/asyncio_filters.py | 87 ++++++++++++++++++- 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/examples/asynchronous_telebot/custom_filters/advanced_text_filter.py b/examples/asynchronous_telebot/custom_filters/advanced_text_filter.py index 184da31..4c363b5 100644 --- a/examples/asynchronous_telebot/custom_filters/advanced_text_filter.py +++ b/examples/asynchronous_telebot/custom_filters/advanced_text_filter.py @@ -6,10 +6,9 @@ with (message_handler, callback_query_handler, poll_handler) """ import asyncio -from telebot.async_telebot import AsyncTeleBot from telebot import types -from telebot.custom_filters import TextFilter -from telebot.asyncio_filters import TextMatchFilter +from telebot.async_telebot import AsyncTeleBot +from telebot.asyncio_filters import TextMatchFilter, TextFilter bot = AsyncTeleBot("") diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index bacc26f..9944ff4 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -1,6 +1,7 @@ from abc import ABC +from typing import Optional, Union -from telebot.custom_filters import TextFilter +from telebot import types class SimpleCustomFilter(ABC): @@ -33,6 +34,88 @@ class AdvancedCustomFilter(ABC): pass +class TextFilter: + """ + Advanced async text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll) + + example of usage is in examples/custom_filters/advanced_text_filter.py + """ + + def __init__(self, + equals: Optional[str] = None, + contains: Optional[Union[list, tuple]] = None, + starts_with: Optional[str] = None, + ends_with: Optional[str] = None, + ignore_case: bool = False): + + """ + :param equals: string, True if object's text is equal to passed string + :param contains: list[str] or tuple[str], True if object's text is in list or tuple + :param starts_with: string, True if object's text starts with passed string + :param ends_with: string, True if object's text starts with passed string + :param ignore_case: bool (default False), case insensitive + """ + + to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with))) + if to_check == 0: + raise ValueError('None of the check modes was specified') + elif to_check > 1: + raise ValueError('Only one check mode can be specified') + elif contains: + for i in contains: + if not isinstance(i, str): + raise ValueError(f"Invalid value '{i}' is in contains") + elif starts_with and not isinstance(starts_with, str): + raise ValueError("starts_with has to be a string") + elif ends_with and not isinstance(ends_with, str): + raise ValueError("ends_with has to be a string") + + self.equals = equals + self.contains = contains + self.starts_with = starts_with + self.ends_with = ends_with + self.ignore_case = ignore_case + + async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]): + + if isinstance(obj, types.Poll): + text = obj.question + elif isinstance(obj, types.Message): + text = obj.text or obj.caption + elif isinstance(obj, types.CallbackQuery): + text = obj.data + elif isinstance(obj, types.InlineQuery): + text = obj.query + else: + return False + + if self.ignore_case: + text = text.lower() + + if self.equals: + self.equals = self.equals.lower() + elif self.contains: + self.contains = tuple(i.lower() for i in self.contains) + elif self.starts_with: + self.starts_with = self.starts_with.lower() + elif self.ends_with: + self.ends_with = self.ends_with.lower() + + if self.equals: + return self.equals == text + + if self.contains: + return text in self.contains + + if self.starts_with: + return text.startswith(self.starts_with) + + if self.ends_with: + return text.endswith(self.ends_with) + + return False + + class TextMatchFilter(AdvancedCustomFilter): """ Filter to check Text message. @@ -46,7 +129,7 @@ class TextMatchFilter(AdvancedCustomFilter): async def check(self, message, text): if isinstance(text, TextFilter): - return text.check(message) + return await text.check(message) elif type(text) is list: return message.text in text else: