mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
multiple check patterns && multiple startwith, endswith fields
This commit is contained in:
parent
6e4f2e19d6
commit
6822f18cbb
@ -36,16 +36,16 @@ class AdvancedCustomFilter(ABC):
|
|||||||
|
|
||||||
class TextFilter:
|
class TextFilter:
|
||||||
"""
|
"""
|
||||||
Advanced async text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
|
Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
|
||||||
|
|
||||||
example of usage is in examples/asynchronous_telebot/custom_filters/advanced_text_filter.py
|
example of usage is in examples/custom_filters/advanced_text_filter.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
equals: Optional[str] = None,
|
equals: Optional[str] = None,
|
||||||
contains: Optional[Union[list, tuple]] = None,
|
contains: Optional[Union[list, tuple]] = None,
|
||||||
starts_with: Optional[str] = None,
|
starts_with: Optional[Union[str, list, tuple]] = None,
|
||||||
ends_with: Optional[str] = None,
|
ends_with: Optional[Union[str, list, tuple]] = None,
|
||||||
ignore_case: bool = False):
|
ignore_case: bool = False):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -59,26 +59,24 @@ class TextFilter:
|
|||||||
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
||||||
if to_check == 0:
|
if to_check == 0:
|
||||||
raise ValueError('None of the check modes was specified')
|
raise ValueError('None of the check modes was specified')
|
||||||
elif to_check > 1:
|
|
||||||
raise ValueError('Only one check mode can be specified')
|
|
||||||
elif contains:
|
|
||||||
if not isinstance(contains, str) and not isinstance(contains, list) and not isinstance(contains, tuple):
|
|
||||||
raise ValueError("Incorrect contains value")
|
|
||||||
elif isinstance(contains, str):
|
|
||||||
contains = [contains]
|
|
||||||
elif isinstance(contains, list) or isinstance(contains, tuple):
|
|
||||||
contains = [i for i in contains if isinstance(i, str)]
|
|
||||||
elif starts_with and not isinstance(starts_with, str):
|
|
||||||
raise ValueError("starts_with has to be a string")
|
|
||||||
elif ends_with and not isinstance(ends_with, str):
|
|
||||||
raise ValueError("ends_with has to be a string")
|
|
||||||
|
|
||||||
self.equals = equals
|
self.equals = equals
|
||||||
self.contains = contains
|
self.contains = self._check_iterable(contains)
|
||||||
self.starts_with = starts_with
|
self.starts_with = self._check_iterable(starts_with)
|
||||||
self.ends_with = ends_with
|
self.ends_with = self._check_iterable(ends_with)
|
||||||
self.ignore_case = ignore_case
|
self.ignore_case = ignore_case
|
||||||
|
|
||||||
|
def _check_iterable(self, iterable):
|
||||||
|
if not iterable:
|
||||||
|
pass
|
||||||
|
elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple):
|
||||||
|
raise ValueError
|
||||||
|
elif isinstance(iterable, str):
|
||||||
|
iterable = [iterable]
|
||||||
|
elif isinstance(iterable, list) or isinstance(iterable, tuple):
|
||||||
|
iterable = [i for i in iterable if isinstance(i, str)]
|
||||||
|
return iterable
|
||||||
|
|
||||||
async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
||||||
|
|
||||||
if isinstance(obj, types.Poll):
|
if isinstance(obj, types.Poll):
|
||||||
@ -98,23 +96,35 @@ class TextFilter:
|
|||||||
if self.equals:
|
if self.equals:
|
||||||
self.equals = self.equals.lower()
|
self.equals = self.equals.lower()
|
||||||
elif self.contains:
|
elif self.contains:
|
||||||
self.contains = tuple(i.lower() for i in self.contains)
|
self.contains = tuple(map(str.lower, self.contains))
|
||||||
elif self.starts_with:
|
elif self.starts_with:
|
||||||
self.starts_with = self.starts_with.lower()
|
self.starts_with = tuple(map(str.lower, self.starts_with))
|
||||||
elif self.ends_with:
|
elif self.ends_with:
|
||||||
self.ends_with = self.ends_with.lower()
|
self.ends_with = tuple(map(str.lower, self.ends_with))
|
||||||
|
|
||||||
if self.equals:
|
if self.equals:
|
||||||
return self.equals == text
|
result = self.equals == text
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
elif not result and not any((self.contains, self.starts_with, self.ends_with)):
|
||||||
|
return False
|
||||||
|
|
||||||
if self.contains:
|
if self.contains:
|
||||||
return any([i in text for i in self.contains])
|
result = any([i in text for i in self.contains])
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
elif not result and not any((self.starts_with, self.ends_with)):
|
||||||
|
return False
|
||||||
|
|
||||||
if self.starts_with:
|
if self.starts_with:
|
||||||
return text.startswith(self.starts_with)
|
result = any([text.startswith(i) for i in self.starts_with])
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
elif not result and not self.ends_with:
|
||||||
|
return False
|
||||||
|
|
||||||
if self.ends_with:
|
if self.ends_with:
|
||||||
return text.endswith(self.ends_with)
|
return any([text.endswith(i) for i in self.ends_with])
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -44,8 +44,8 @@ class TextFilter:
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
equals: Optional[str] = None,
|
equals: Optional[str] = None,
|
||||||
contains: Optional[Union[list, tuple]] = None,
|
contains: Optional[Union[list, tuple]] = None,
|
||||||
starts_with: Optional[str] = None,
|
starts_with: Optional[Union[str, list, tuple]] = None,
|
||||||
ends_with: Optional[str] = None,
|
ends_with: Optional[Union[str, list, tuple]] = None,
|
||||||
ignore_case: bool = False):
|
ignore_case: bool = False):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -59,26 +59,24 @@ class TextFilter:
|
|||||||
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
||||||
if to_check == 0:
|
if to_check == 0:
|
||||||
raise ValueError('None of the check modes was specified')
|
raise ValueError('None of the check modes was specified')
|
||||||
elif to_check > 1:
|
|
||||||
raise ValueError('Only one check mode can be specified')
|
|
||||||
elif contains:
|
|
||||||
if not isinstance(contains, str) and not isinstance(contains, list) and not isinstance(contains, tuple):
|
|
||||||
raise ValueError("Incorrect contains value")
|
|
||||||
elif isinstance(contains, str):
|
|
||||||
contains = [contains]
|
|
||||||
elif isinstance(contains, list) or isinstance(contains, tuple):
|
|
||||||
contains = [i for i in contains if isinstance(i, str)]
|
|
||||||
elif starts_with and not isinstance(starts_with, str):
|
|
||||||
raise ValueError("starts_with has to be a string")
|
|
||||||
elif ends_with and not isinstance(ends_with, str):
|
|
||||||
raise ValueError("ends_with has to be a string")
|
|
||||||
|
|
||||||
self.equals = equals
|
self.equals = equals
|
||||||
self.contains = contains
|
self.contains = self._check_iterable(contains)
|
||||||
self.starts_with = starts_with
|
self.starts_with = self._check_iterable(starts_with)
|
||||||
self.ends_with = ends_with
|
self.ends_with = self._check_iterable(ends_with)
|
||||||
self.ignore_case = ignore_case
|
self.ignore_case = ignore_case
|
||||||
|
|
||||||
|
def _check_iterable(self, iterable):
|
||||||
|
if not iterable:
|
||||||
|
pass
|
||||||
|
elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple):
|
||||||
|
raise ValueError
|
||||||
|
elif isinstance(iterable, str):
|
||||||
|
iterable = [iterable]
|
||||||
|
elif isinstance(iterable, list) or isinstance(iterable, tuple):
|
||||||
|
iterable = [i for i in iterable if isinstance(i, str)]
|
||||||
|
return iterable
|
||||||
|
|
||||||
def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
||||||
|
|
||||||
if isinstance(obj, types.Poll):
|
if isinstance(obj, types.Poll):
|
||||||
@ -98,26 +96,37 @@ class TextFilter:
|
|||||||
if self.equals:
|
if self.equals:
|
||||||
self.equals = self.equals.lower()
|
self.equals = self.equals.lower()
|
||||||
elif self.contains:
|
elif self.contains:
|
||||||
self.contains = tuple(i.lower() for i in self.contains)
|
self.contains = tuple(map(str.lower, self.contains))
|
||||||
elif self.starts_with:
|
elif self.starts_with:
|
||||||
self.starts_with = self.starts_with.lower()
|
self.starts_with = tuple(map(str.lower, self.starts_with))
|
||||||
elif self.ends_with:
|
elif self.ends_with:
|
||||||
self.ends_with = self.ends_with.lower()
|
self.ends_with = tuple(map(str.lower, self.ends_with))
|
||||||
|
|
||||||
if self.equals:
|
if self.equals:
|
||||||
return self.equals == text
|
result = self.equals == text
|
||||||
|
if result:
|
||||||
if self.contains:
|
return True
|
||||||
return any([i in text for i in self.contains])
|
elif not result and not any((self.contains, self.starts_with, self.ends_with)):
|
||||||
|
|
||||||
if self.starts_with:
|
|
||||||
return text.startswith(self.starts_with)
|
|
||||||
|
|
||||||
if self.ends_with:
|
|
||||||
return text.endswith(self.ends_with)
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if self.contains:
|
||||||
|
result = any([i in text for i in self.contains])
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
elif not result and not any((self.starts_with, self.ends_with)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.starts_with:
|
||||||
|
result = any([text.startswith(i) for i in self.starts_with])
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
elif not result and not self.ends_with:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.ends_with:
|
||||||
|
return any([text.endswith(i) for i in self.ends_with])
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
class TextMatchFilter(AdvancedCustomFilter):
|
class TextMatchFilter(AdvancedCustomFilter):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user