created async TextFilter

This commit is contained in:
abdullaev388 2022-02-12 17:33:29 +05:00
parent 5b1483f646
commit 5f7ccc8c9b
2 changed files with 87 additions and 5 deletions

View File

@ -6,10 +6,9 @@ with (message_handler, callback_query_handler, poll_handler)
"""
import asyncio
from telebot.async_telebot import AsyncTeleBot
from telebot import types
from telebot.custom_filters import TextFilter
from telebot.asyncio_filters import TextMatchFilter
from telebot.async_telebot import AsyncTeleBot
from telebot.asyncio_filters import TextMatchFilter, TextFilter
bot = AsyncTeleBot("")

View File

@ -1,6 +1,7 @@
from abc import ABC
from typing import Optional, Union
from telebot.custom_filters import TextFilter
from telebot import types
class SimpleCustomFilter(ABC):
@ -33,6 +34,88 @@ class AdvancedCustomFilter(ABC):
pass
class TextFilter:
"""
Advanced async text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
example of usage is in examples/custom_filters/advanced_text_filter.py
"""
def __init__(self,
equals: Optional[str] = None,
contains: Optional[Union[list, tuple]] = None,
starts_with: Optional[str] = None,
ends_with: Optional[str] = None,
ignore_case: bool = False):
"""
:param equals: string, True if object's text is equal to passed string
:param contains: list[str] or tuple[str], True if object's text is in list or tuple
:param starts_with: string, True if object's text starts with passed string
:param ends_with: string, True if object's text starts with passed string
:param ignore_case: bool (default False), case insensitive
"""
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
if to_check == 0:
raise ValueError('None of the check modes was specified')
elif to_check > 1:
raise ValueError('Only one check mode can be specified')
elif contains:
for i in contains:
if not isinstance(i, str):
raise ValueError(f"Invalid value '{i}' is in contains")
elif starts_with and not isinstance(starts_with, str):
raise ValueError("starts_with has to be a string")
elif ends_with and not isinstance(ends_with, str):
raise ValueError("ends_with has to be a string")
self.equals = equals
self.contains = contains
self.starts_with = starts_with
self.ends_with = ends_with
self.ignore_case = ignore_case
async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
if isinstance(obj, types.Poll):
text = obj.question
elif isinstance(obj, types.Message):
text = obj.text or obj.caption
elif isinstance(obj, types.CallbackQuery):
text = obj.data
elif isinstance(obj, types.InlineQuery):
text = obj.query
else:
return False
if self.ignore_case:
text = text.lower()
if self.equals:
self.equals = self.equals.lower()
elif self.contains:
self.contains = tuple(i.lower() for i in self.contains)
elif self.starts_with:
self.starts_with = self.starts_with.lower()
elif self.ends_with:
self.ends_with = self.ends_with.lower()
if self.equals:
return self.equals == text
if self.contains:
return text in self.contains
if self.starts_with:
return text.startswith(self.starts_with)
if self.ends_with:
return text.endswith(self.ends_with)
return False
class TextMatchFilter(AdvancedCustomFilter):
"""
Filter to check Text message.
@ -46,7 +129,7 @@ class TextMatchFilter(AdvancedCustomFilter):
async def check(self, message, text):
if isinstance(text, TextFilter):
return text.check(message)
return await text.check(message)
elif type(text) is list:
return message.text in text
else: