mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
created async TextFilter
This commit is contained in:
parent
5b1483f646
commit
5f7ccc8c9b
@ -6,10 +6,9 @@ with (message_handler, callback_query_handler, poll_handler)
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from telebot.async_telebot import AsyncTeleBot
|
|
||||||
from telebot import types
|
from telebot import types
|
||||||
from telebot.custom_filters import TextFilter
|
from telebot.async_telebot import AsyncTeleBot
|
||||||
from telebot.asyncio_filters import TextMatchFilter
|
from telebot.asyncio_filters import TextMatchFilter, TextFilter
|
||||||
|
|
||||||
bot = AsyncTeleBot("")
|
bot = AsyncTeleBot("")
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
from telebot.custom_filters import TextFilter
|
from telebot import types
|
||||||
|
|
||||||
|
|
||||||
class SimpleCustomFilter(ABC):
|
class SimpleCustomFilter(ABC):
|
||||||
@ -33,6 +34,88 @@ class AdvancedCustomFilter(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TextFilter:
|
||||||
|
"""
|
||||||
|
Advanced async text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
|
||||||
|
|
||||||
|
example of usage is in examples/custom_filters/advanced_text_filter.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
equals: Optional[str] = None,
|
||||||
|
contains: Optional[Union[list, tuple]] = None,
|
||||||
|
starts_with: Optional[str] = None,
|
||||||
|
ends_with: Optional[str] = None,
|
||||||
|
ignore_case: bool = False):
|
||||||
|
|
||||||
|
"""
|
||||||
|
:param equals: string, True if object's text is equal to passed string
|
||||||
|
:param contains: list[str] or tuple[str], True if object's text is in list or tuple
|
||||||
|
:param starts_with: string, True if object's text starts with passed string
|
||||||
|
:param ends_with: string, True if object's text starts with passed string
|
||||||
|
:param ignore_case: bool (default False), case insensitive
|
||||||
|
"""
|
||||||
|
|
||||||
|
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
||||||
|
if to_check == 0:
|
||||||
|
raise ValueError('None of the check modes was specified')
|
||||||
|
elif to_check > 1:
|
||||||
|
raise ValueError('Only one check mode can be specified')
|
||||||
|
elif contains:
|
||||||
|
for i in contains:
|
||||||
|
if not isinstance(i, str):
|
||||||
|
raise ValueError(f"Invalid value '{i}' is in contains")
|
||||||
|
elif starts_with and not isinstance(starts_with, str):
|
||||||
|
raise ValueError("starts_with has to be a string")
|
||||||
|
elif ends_with and not isinstance(ends_with, str):
|
||||||
|
raise ValueError("ends_with has to be a string")
|
||||||
|
|
||||||
|
self.equals = equals
|
||||||
|
self.contains = contains
|
||||||
|
self.starts_with = starts_with
|
||||||
|
self.ends_with = ends_with
|
||||||
|
self.ignore_case = ignore_case
|
||||||
|
|
||||||
|
async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
||||||
|
|
||||||
|
if isinstance(obj, types.Poll):
|
||||||
|
text = obj.question
|
||||||
|
elif isinstance(obj, types.Message):
|
||||||
|
text = obj.text or obj.caption
|
||||||
|
elif isinstance(obj, types.CallbackQuery):
|
||||||
|
text = obj.data
|
||||||
|
elif isinstance(obj, types.InlineQuery):
|
||||||
|
text = obj.query
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.ignore_case:
|
||||||
|
text = text.lower()
|
||||||
|
|
||||||
|
if self.equals:
|
||||||
|
self.equals = self.equals.lower()
|
||||||
|
elif self.contains:
|
||||||
|
self.contains = tuple(i.lower() for i in self.contains)
|
||||||
|
elif self.starts_with:
|
||||||
|
self.starts_with = self.starts_with.lower()
|
||||||
|
elif self.ends_with:
|
||||||
|
self.ends_with = self.ends_with.lower()
|
||||||
|
|
||||||
|
if self.equals:
|
||||||
|
return self.equals == text
|
||||||
|
|
||||||
|
if self.contains:
|
||||||
|
return text in self.contains
|
||||||
|
|
||||||
|
if self.starts_with:
|
||||||
|
return text.startswith(self.starts_with)
|
||||||
|
|
||||||
|
if self.ends_with:
|
||||||
|
return text.endswith(self.ends_with)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TextMatchFilter(AdvancedCustomFilter):
|
class TextMatchFilter(AdvancedCustomFilter):
|
||||||
"""
|
"""
|
||||||
Filter to check Text message.
|
Filter to check Text message.
|
||||||
@ -46,7 +129,7 @@ class TextMatchFilter(AdvancedCustomFilter):
|
|||||||
|
|
||||||
async def check(self, message, text):
|
async def check(self, message, text):
|
||||||
if isinstance(text, TextFilter):
|
if isinstance(text, TextFilter):
|
||||||
return text.check(message)
|
return await text.check(message)
|
||||||
elif type(text) is list:
|
elif type(text) is list:
|
||||||
return message.text in text
|
return message.text in text
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user