mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
created async TextFilter
This commit is contained in:
parent
5b1483f646
commit
5f7ccc8c9b
@ -6,10 +6,9 @@ with (message_handler, callback_query_handler, poll_handler)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
from telebot import types
|
||||
from telebot.custom_filters import TextFilter
|
||||
from telebot.asyncio_filters import TextMatchFilter
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
from telebot.asyncio_filters import TextMatchFilter, TextFilter
|
||||
|
||||
bot = AsyncTeleBot("")
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
from typing import Optional, Union
|
||||
|
||||
from telebot.custom_filters import TextFilter
|
||||
from telebot import types
|
||||
|
||||
|
||||
class SimpleCustomFilter(ABC):
|
||||
@ -33,6 +34,88 @@ class AdvancedCustomFilter(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class TextFilter:
|
||||
"""
|
||||
Advanced async text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
|
||||
|
||||
example of usage is in examples/custom_filters/advanced_text_filter.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
equals: Optional[str] = None,
|
||||
contains: Optional[Union[list, tuple]] = None,
|
||||
starts_with: Optional[str] = None,
|
||||
ends_with: Optional[str] = None,
|
||||
ignore_case: bool = False):
|
||||
|
||||
"""
|
||||
:param equals: string, True if object's text is equal to passed string
|
||||
:param contains: list[str] or tuple[str], True if object's text is in list or tuple
|
||||
:param starts_with: string, True if object's text starts with passed string
|
||||
:param ends_with: string, True if object's text starts with passed string
|
||||
:param ignore_case: bool (default False), case insensitive
|
||||
"""
|
||||
|
||||
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
||||
if to_check == 0:
|
||||
raise ValueError('None of the check modes was specified')
|
||||
elif to_check > 1:
|
||||
raise ValueError('Only one check mode can be specified')
|
||||
elif contains:
|
||||
for i in contains:
|
||||
if not isinstance(i, str):
|
||||
raise ValueError(f"Invalid value '{i}' is in contains")
|
||||
elif starts_with and not isinstance(starts_with, str):
|
||||
raise ValueError("starts_with has to be a string")
|
||||
elif ends_with and not isinstance(ends_with, str):
|
||||
raise ValueError("ends_with has to be a string")
|
||||
|
||||
self.equals = equals
|
||||
self.contains = contains
|
||||
self.starts_with = starts_with
|
||||
self.ends_with = ends_with
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
||||
|
||||
if isinstance(obj, types.Poll):
|
||||
text = obj.question
|
||||
elif isinstance(obj, types.Message):
|
||||
text = obj.text or obj.caption
|
||||
elif isinstance(obj, types.CallbackQuery):
|
||||
text = obj.data
|
||||
elif isinstance(obj, types.InlineQuery):
|
||||
text = obj.query
|
||||
else:
|
||||
return False
|
||||
|
||||
if self.ignore_case:
|
||||
text = text.lower()
|
||||
|
||||
if self.equals:
|
||||
self.equals = self.equals.lower()
|
||||
elif self.contains:
|
||||
self.contains = tuple(i.lower() for i in self.contains)
|
||||
elif self.starts_with:
|
||||
self.starts_with = self.starts_with.lower()
|
||||
elif self.ends_with:
|
||||
self.ends_with = self.ends_with.lower()
|
||||
|
||||
if self.equals:
|
||||
return self.equals == text
|
||||
|
||||
if self.contains:
|
||||
return text in self.contains
|
||||
|
||||
if self.starts_with:
|
||||
return text.startswith(self.starts_with)
|
||||
|
||||
if self.ends_with:
|
||||
return text.endswith(self.ends_with)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TextMatchFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check Text message.
|
||||
@ -46,7 +129,7 @@ class TextMatchFilter(AdvancedCustomFilter):
|
||||
|
||||
async def check(self, message, text):
|
||||
if isinstance(text, TextFilter):
|
||||
return text.check(message)
|
||||
return await text.check(message)
|
||||
elif type(text) is list:
|
||||
return message.text in text
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user