diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index 5fa003d..a8b7180 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -36,16 +36,16 @@ class AdvancedCustomFilter(ABC): class TextFilter: """ - Advanced async text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll) + Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll) - example of usage is in examples/asynchronous_telebot/custom_filters/advanced_text_filter.py + example of usage is in examples/custom_filters/advanced_text_filter.py """ def __init__(self, equals: Optional[str] = None, contains: Optional[Union[list, tuple]] = None, - starts_with: Optional[str] = None, - ends_with: Optional[str] = None, + starts_with: Optional[Union[str, list, tuple]] = None, + ends_with: Optional[Union[str, list, tuple]] = None, ignore_case: bool = False): """ @@ -59,26 +59,24 @@ class TextFilter: to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with))) if to_check == 0: raise ValueError('None of the check modes was specified') - elif to_check > 1: - raise ValueError('Only one check mode can be specified') - elif contains: - if not isinstance(contains, str) and not isinstance(contains, list) and not isinstance(contains, tuple): - raise ValueError("Incorrect contains value") - elif isinstance(contains, str): - contains = [contains] - elif isinstance(contains, list) or isinstance(contains, tuple): - contains = [i for i in contains if isinstance(i, str)] - elif starts_with and not isinstance(starts_with, str): - raise ValueError("starts_with has to be a string") - elif ends_with and not isinstance(ends_with, str): - raise ValueError("ends_with has to be a string") self.equals = equals - self.contains = contains - self.starts_with = starts_with - self.ends_with = ends_with + self.contains = self._check_iterable(contains) + self.starts_with = self._check_iterable(starts_with) + self.ends_with = self._check_iterable(ends_with) self.ignore_case = ignore_case + def _check_iterable(self, iterable): + if not iterable: + pass + elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple): + raise ValueError + elif isinstance(iterable, str): + iterable = [iterable] + elif isinstance(iterable, list) or isinstance(iterable, tuple): + iterable = [i for i in iterable if isinstance(i, str)] + return iterable + async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]): if isinstance(obj, types.Poll): @@ -98,23 +96,35 @@ class TextFilter: if self.equals: self.equals = self.equals.lower() elif self.contains: - self.contains = tuple(i.lower() for i in self.contains) + self.contains = tuple(map(str.lower, self.contains)) elif self.starts_with: - self.starts_with = self.starts_with.lower() + self.starts_with = tuple(map(str.lower, self.starts_with)) elif self.ends_with: - self.ends_with = self.ends_with.lower() + self.ends_with = tuple(map(str.lower, self.ends_with)) if self.equals: - return self.equals == text + result = self.equals == text + if result: + return True + elif not result and not any((self.contains, self.starts_with, self.ends_with)): + return False if self.contains: - return any([i in text for i in self.contains]) + result = any([i in text for i in self.contains]) + if result: + return True + elif not result and not any((self.starts_with, self.ends_with)): + return False if self.starts_with: - return text.startswith(self.starts_with) + result = any([text.startswith(i) for i in self.starts_with]) + if result: + return True + elif not result and not self.ends_with: + return False if self.ends_with: - return text.endswith(self.ends_with) + return any([text.endswith(i) for i in self.ends_with]) return False diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index 6eea0f0..145cc74 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -44,8 +44,8 @@ class TextFilter: def __init__(self, equals: Optional[str] = None, contains: Optional[Union[list, tuple]] = None, - starts_with: Optional[str] = None, - ends_with: Optional[str] = None, + starts_with: Optional[Union[str, list, tuple]] = None, + ends_with: Optional[Union[str, list, tuple]] = None, ignore_case: bool = False): """ @@ -59,26 +59,24 @@ class TextFilter: to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with))) if to_check == 0: raise ValueError('None of the check modes was specified') - elif to_check > 1: - raise ValueError('Only one check mode can be specified') - elif contains: - if not isinstance(contains, str) and not isinstance(contains, list) and not isinstance(contains, tuple): - raise ValueError("Incorrect contains value") - elif isinstance(contains, str): - contains = [contains] - elif isinstance(contains, list) or isinstance(contains, tuple): - contains = [i for i in contains if isinstance(i, str)] - elif starts_with and not isinstance(starts_with, str): - raise ValueError("starts_with has to be a string") - elif ends_with and not isinstance(ends_with, str): - raise ValueError("ends_with has to be a string") self.equals = equals - self.contains = contains - self.starts_with = starts_with - self.ends_with = ends_with + self.contains = self._check_iterable(contains) + self.starts_with = self._check_iterable(starts_with) + self.ends_with = self._check_iterable(ends_with) self.ignore_case = ignore_case + def _check_iterable(self, iterable): + if not iterable: + pass + elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple): + raise ValueError + elif isinstance(iterable, str): + iterable = [iterable] + elif isinstance(iterable, list) or isinstance(iterable, tuple): + iterable = [i for i in iterable if isinstance(i, str)] + return iterable + def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]): if isinstance(obj, types.Poll): @@ -98,27 +96,38 @@ class TextFilter: if self.equals: self.equals = self.equals.lower() elif self.contains: - self.contains = tuple(i.lower() for i in self.contains) + self.contains = tuple(map(str.lower, self.contains)) elif self.starts_with: - self.starts_with = self.starts_with.lower() + self.starts_with = tuple(map(str.lower, self.starts_with)) elif self.ends_with: - self.ends_with = self.ends_with.lower() + self.ends_with = tuple(map(str.lower, self.ends_with)) if self.equals: - return self.equals == text + result = self.equals == text + if result: + return True + elif not result and not any((self.contains, self.starts_with, self.ends_with)): + return False if self.contains: - return any([i in text for i in self.contains]) + result = any([i in text for i in self.contains]) + if result: + return True + elif not result and not any((self.starts_with, self.ends_with)): + return False if self.starts_with: - return text.startswith(self.starts_with) + result = any([text.startswith(i) for i in self.starts_with]) + if result: + return True + elif not result and not self.ends_with: + return False if self.ends_with: - return text.endswith(self.ends_with) + return any([text.endswith(i) for i in self.ends_with]) return False - class TextMatchFilter(AdvancedCustomFilter): """ Filter to check Text message.