mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
new advanced TextFilter was added && An example demostrating TextFilter usage
This commit is contained in:
@ -1,4 +1,9 @@
|
||||
from abc import ABC
|
||||
from typing import Optional, Union
|
||||
|
||||
from telebot import types
|
||||
|
||||
|
||||
class SimpleCustomFilter(ABC):
|
||||
"""
|
||||
Simple Custom Filter base class.
|
||||
@ -29,6 +34,84 @@ class AdvancedCustomFilter(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class TextFilter:
|
||||
"""
|
||||
Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
|
||||
|
||||
example of usage is in examples/custom_filters/advanced_text_filter.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
equals: Optional[str] = None,
|
||||
contains: Optional[Union[list, tuple]] = None,
|
||||
starts_with: Optional[str] = None,
|
||||
ends_with: Optional[str] = None,
|
||||
ignore_case: bool = False):
|
||||
|
||||
"""
|
||||
:param equals: string, True if object's text is equal to passed string
|
||||
:param contains: list[str] or tuple[str], True if object's text is in list or tuple
|
||||
:param starts_with: string, True if object's text starts with passed string
|
||||
:param ends_with: string, True if object's text starts with passed string
|
||||
:param ignore_case: bool (default False), case insensitive
|
||||
"""
|
||||
|
||||
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
||||
if to_check == 0:
|
||||
raise ValueError('None of the check modes was specified')
|
||||
elif to_check > 1:
|
||||
raise ValueError('Only one check mode can be specified')
|
||||
elif contains:
|
||||
for i in contains:
|
||||
if not isinstance(i, str):
|
||||
raise ValueError(f"Invalid value '{i}' is in contains")
|
||||
elif starts_with and not isinstance(starts_with, str):
|
||||
raise ValueError("starts_with has to be a string")
|
||||
elif ends_with and not isinstance(ends_with, str):
|
||||
raise ValueError("ends_with has to be a string")
|
||||
|
||||
self.equals = equals
|
||||
self.contains = contains
|
||||
self.starts_with = starts_with
|
||||
self.ends_with = ends_with
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
||||
|
||||
if isinstance(obj, types.Poll):
|
||||
text = obj.question
|
||||
elif isinstance(obj, types.Message):
|
||||
text = obj.text or obj.caption
|
||||
elif isinstance(obj, types.CallbackQuery):
|
||||
text = obj.data
|
||||
elif isinstance(obj, types.InlineQuery):
|
||||
text = obj.query
|
||||
else:
|
||||
return False
|
||||
|
||||
if self.equals:
|
||||
if self.ignore_case:
|
||||
return self.equals.lower() == text.lower()
|
||||
return self.equals == text
|
||||
|
||||
if self.contains:
|
||||
if self.ignore_case:
|
||||
return text.lower() in (i.lower() for i in self.contains)
|
||||
return text in self.contains
|
||||
|
||||
if self.starts_with:
|
||||
if self.ignore_case:
|
||||
return text.lower().startswith(self.starts_with)
|
||||
return text.startswith(self.starts_with)
|
||||
|
||||
if self.ends_with:
|
||||
if self.ignore_case:
|
||||
return text.lower().endswith(self.ends_with)
|
||||
return text.endswith(self.ends_with)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TextMatchFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check Text message.
|
||||
@ -41,8 +124,11 @@ class TextMatchFilter(AdvancedCustomFilter):
|
||||
key = 'text'
|
||||
|
||||
def check(self, message, text):
|
||||
if type(text) is list:return message.text in text
|
||||
else: return text == message.text
|
||||
if type(text) is list:
|
||||
return message.text in text
|
||||
else:
|
||||
return text == message.text
|
||||
|
||||
|
||||
class TextContainsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
@ -59,6 +145,7 @@ class TextContainsFilter(AdvancedCustomFilter):
|
||||
def check(self, message, text):
|
||||
return text in message.text
|
||||
|
||||
|
||||
class TextStartsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check whether message starts with some text.
|
||||
@ -69,8 +156,10 @@ class TextStartsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
|
||||
key = 'text_startswith'
|
||||
|
||||
def check(self, message, text):
|
||||
return message.text.startswith(text)
|
||||
return message.text.startswith(text)
|
||||
|
||||
|
||||
class ChatFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
@ -81,9 +170,11 @@ class ChatFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
|
||||
key = 'chat_id'
|
||||
|
||||
def check(self, message, text):
|
||||
return message.chat.id in text
|
||||
|
||||
|
||||
class ForwardFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Check whether message was forwarded from channel or group.
|
||||
@ -98,6 +189,7 @@ class ForwardFilter(SimpleCustomFilter):
|
||||
def check(self, message):
|
||||
return message.forward_from_chat is not None
|
||||
|
||||
|
||||
class IsReplyFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Check whether message is a reply.
|
||||
@ -113,7 +205,6 @@ class IsReplyFilter(SimpleCustomFilter):
|
||||
return message.reply_to_message is not None
|
||||
|
||||
|
||||
|
||||
class LanguageFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Check users language_code.
|
||||
@ -126,8 +217,11 @@ class LanguageFilter(AdvancedCustomFilter):
|
||||
key = 'language_code'
|
||||
|
||||
def check(self, message, text):
|
||||
if type(text) is list:return message.from_user.language_code in text
|
||||
else: return message.from_user.language_code == text
|
||||
if type(text) is list:
|
||||
return message.from_user.language_code in text
|
||||
else:
|
||||
return message.from_user.language_code == text
|
||||
|
||||
|
||||
class IsAdminFilter(SimpleCustomFilter):
|
||||
"""
|
||||
@ -145,6 +239,7 @@ class IsAdminFilter(SimpleCustomFilter):
|
||||
def check(self, message):
|
||||
return self._bot.get_chat_member(message.chat.id, message.from_user.id).status in ['creator', 'administrator']
|
||||
|
||||
|
||||
class StateFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check state.
|
||||
@ -152,8 +247,10 @@ class StateFilter(AdvancedCustomFilter):
|
||||
Example:
|
||||
@bot.message_handler(state=1)
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
key = 'state'
|
||||
|
||||
def check(self, message, text):
|
||||
@ -170,14 +267,16 @@ class StateFilter(AdvancedCustomFilter):
|
||||
return True
|
||||
elif group_state in text and type(text) is list:
|
||||
return True
|
||||
|
||||
|
||||
|
||||
|
||||
else:
|
||||
user_state = self.bot.current_states.get_state(message.chat.id,message.from_user.id)
|
||||
user_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id)
|
||||
if user_state == text:
|
||||
return True
|
||||
elif type(text) is list and user_state in text:
|
||||
return True
|
||||
|
||||
|
||||
class IsDigitFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Filter to check whether the string is made up of only digits.
|
||||
|
Reference in New Issue
Block a user