mirror of
https://github.com/eternnoir/pyTelegramBotAPI.git
synced 2023-08-10 21:12:57 +03:00
Merge pull request #1449 from abdullaev388/master
new advanced TextFilter was added && An example demostrating TextFilt…
This commit is contained in:
commit
0ef8d04ed2
@ -0,0 +1,26 @@
|
||||
from telebot import types
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
from telebot.asyncio_filters import AdvancedCustomFilter
|
||||
from telebot.callback_data import CallbackData, CallbackDataFilter
|
||||
|
||||
calendar_factory = CallbackData("year", "month", prefix="calendar")
|
||||
calendar_zoom = CallbackData("year", prefix="calendar_zoom")
|
||||
|
||||
|
||||
class CalendarCallbackFilter(AdvancedCustomFilter):
|
||||
key = 'calendar_config'
|
||||
|
||||
async def check(self, call: types.CallbackQuery, config: CallbackDataFilter):
|
||||
return config.check(query=call)
|
||||
|
||||
|
||||
class CalendarZoomCallbackFilter(AdvancedCustomFilter):
|
||||
key = 'calendar_zoom_config'
|
||||
|
||||
async def check(self, call: types.CallbackQuery, config: CallbackDataFilter):
|
||||
return config.check(query=call)
|
||||
|
||||
|
||||
def bind_filters(bot: AsyncTeleBot):
|
||||
bot.add_custom_filter(CalendarCallbackFilter())
|
||||
bot.add_custom_filter(CalendarZoomCallbackFilter())
|
@ -0,0 +1,92 @@
|
||||
import calendar
|
||||
from datetime import date, timedelta
|
||||
|
||||
from filters import calendar_factory, calendar_zoom
|
||||
from telebot.types import InlineKeyboardMarkup, InlineKeyboardButton
|
||||
|
||||
EMTPY_FIELD = '1'
|
||||
WEEK_DAYS = [calendar.day_abbr[i] for i in range(7)]
|
||||
MONTHS = [(i, calendar.month_name[i]) for i in range(1, 13)]
|
||||
|
||||
|
||||
def generate_calendar_days(year: int, month: int):
|
||||
keyboard = InlineKeyboardMarkup(row_width=7)
|
||||
today = date.today()
|
||||
|
||||
keyboard.add(
|
||||
InlineKeyboardButton(
|
||||
text=date(year=year, month=month, day=1).strftime('%b %Y'),
|
||||
callback_data=EMTPY_FIELD
|
||||
)
|
||||
)
|
||||
keyboard.add(*[
|
||||
InlineKeyboardButton(
|
||||
text=day,
|
||||
callback_data=EMTPY_FIELD
|
||||
)
|
||||
for day in WEEK_DAYS
|
||||
])
|
||||
|
||||
for week in calendar.Calendar().monthdayscalendar(year=year, month=month):
|
||||
week_buttons = []
|
||||
for day in week:
|
||||
day_name = ' '
|
||||
if day == today.day and today.year == year and today.month == month:
|
||||
day_name = '🔘'
|
||||
elif day != 0:
|
||||
day_name = str(day)
|
||||
week_buttons.append(
|
||||
InlineKeyboardButton(
|
||||
text=day_name,
|
||||
callback_data=EMTPY_FIELD
|
||||
)
|
||||
)
|
||||
keyboard.add(*week_buttons)
|
||||
|
||||
previous_date = date(year=year, month=month, day=1) - timedelta(days=1)
|
||||
next_date = date(year=year, month=month, day=1) + timedelta(days=31)
|
||||
|
||||
keyboard.add(
|
||||
InlineKeyboardButton(
|
||||
text='Previous month',
|
||||
callback_data=calendar_factory.new(year=previous_date.year, month=previous_date.month)
|
||||
),
|
||||
InlineKeyboardButton(
|
||||
text='Zoom out',
|
||||
callback_data=calendar_zoom.new(year=year)
|
||||
),
|
||||
InlineKeyboardButton(
|
||||
text='Next month',
|
||||
callback_data=calendar_factory.new(year=next_date.year, month=next_date.month)
|
||||
),
|
||||
)
|
||||
|
||||
return keyboard
|
||||
|
||||
|
||||
def generate_calendar_months(year: int):
|
||||
keyboard = InlineKeyboardMarkup(row_width=3)
|
||||
keyboard.add(
|
||||
InlineKeyboardButton(
|
||||
text=date(year=year, month=1, day=1).strftime('Year %Y'),
|
||||
callback_data=EMTPY_FIELD
|
||||
)
|
||||
)
|
||||
keyboard.add(*[
|
||||
InlineKeyboardButton(
|
||||
text=month,
|
||||
callback_data=calendar_factory.new(year=year, month=month_number)
|
||||
)
|
||||
for month_number, month in MONTHS
|
||||
])
|
||||
keyboard.add(
|
||||
InlineKeyboardButton(
|
||||
text='Previous year',
|
||||
callback_data=calendar_zoom.new(year=year - 1)
|
||||
),
|
||||
InlineKeyboardButton(
|
||||
text='Next year',
|
||||
callback_data=calendar_zoom.new(year=year + 1)
|
||||
)
|
||||
)
|
||||
return keyboard
|
@ -0,0 +1,57 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This Example will show you an advanced usage of CallbackData.
|
||||
In this example calendar was implemented
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import date
|
||||
|
||||
from filters import calendar_factory, calendar_zoom, bind_filters
|
||||
from keyboards import generate_calendar_days, generate_calendar_months, EMTPY_FIELD
|
||||
from telebot import types
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
|
||||
API_TOKEN = ''
|
||||
bot = AsyncTeleBot(API_TOKEN)
|
||||
|
||||
|
||||
@bot.message_handler(commands='start')
|
||||
async def start_command_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id,
|
||||
f"Hello {message.from_user.first_name}. This bot is an example of calendar keyboard."
|
||||
"\nPress /calendar to see it.")
|
||||
|
||||
|
||||
@bot.message_handler(commands='calendar')
|
||||
async def calendar_command_handler(message: types.Message):
|
||||
now = date.today()
|
||||
await bot.send_message(message.chat.id, 'Calendar',
|
||||
reply_markup=generate_calendar_days(year=now.year, month=now.month))
|
||||
|
||||
|
||||
@bot.callback_query_handler(func=None, calendar_config=calendar_factory.filter())
|
||||
async def calendar_action_handler(call: types.CallbackQuery):
|
||||
callback_data: dict = calendar_factory.parse(callback_data=call.data)
|
||||
year, month = int(callback_data['year']), int(callback_data['month'])
|
||||
|
||||
await bot.edit_message_reply_markup(call.message.chat.id, call.message.id,
|
||||
reply_markup=generate_calendar_days(year=year, month=month))
|
||||
|
||||
|
||||
@bot.callback_query_handler(func=None, calendar_zoom_config=calendar_zoom.filter())
|
||||
async def calendar_zoom_out_handler(call: types.CallbackQuery):
|
||||
callback_data: dict = calendar_zoom.parse(callback_data=call.data)
|
||||
year = int(callback_data.get('year'))
|
||||
|
||||
await bot.edit_message_reply_markup(call.message.chat.id, call.message.id,
|
||||
reply_markup=generate_calendar_months(year=year))
|
||||
|
||||
|
||||
@bot.callback_query_handler(func=lambda call: call.data == EMTPY_FIELD)
|
||||
async def callback_empty_field_handler(call: types.CallbackQuery):
|
||||
await bot.answer_callback_query(call.id)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
bind_filters(bot)
|
||||
asyncio.run(bot.infinity_polling())
|
@ -0,0 +1,127 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This Example will show you usage of TextFilter
|
||||
In this example you will see how to use TextFilter
|
||||
with (message_handler, callback_query_handler, poll_handler)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from telebot import types
|
||||
from telebot.async_telebot import AsyncTeleBot
|
||||
from telebot.asyncio_filters import TextMatchFilter, TextFilter, IsReplyFilter
|
||||
|
||||
bot = AsyncTeleBot("")
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='hello'))
|
||||
async def hello_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='hello', ignore_case=True))
|
||||
async def hello_handler_ignore_case(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(contains=['good', 'bad']))
|
||||
async def contains_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(contains=['good', 'bad'], ignore_case=True))
|
||||
async def contains_handler_ignore_case(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(starts_with='st')) # stArk, steve, stONE
|
||||
async def starts_with_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(starts_with='st', ignore_case=True)) # STark, sTeve, stONE
|
||||
async def starts_with_handler_ignore_case(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(ends_with='ay')) # wednesday, SUNday, WeekDay
|
||||
async def ends_with_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(ends_with='ay', ignore_case=True)) # wednesdAY, sundAy, WeekdaY
|
||||
async def ends_with_handler_ignore_case(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='/callback'))
|
||||
async def send_callback(message: types.Message):
|
||||
keyboard = types.InlineKeyboardMarkup(
|
||||
keyboard=[
|
||||
[types.InlineKeyboardButton(text='callback data', callback_data='example')],
|
||||
[types.InlineKeyboardButton(text='ignore case callback data', callback_data='ExAmPLe')]
|
||||
]
|
||||
)
|
||||
await bot.send_message(message.chat.id, message.text, reply_markup=keyboard)
|
||||
|
||||
|
||||
@bot.callback_query_handler(func=None, text=TextFilter(equals='example'))
|
||||
async def callback_query_handler(call: types.CallbackQuery):
|
||||
await bot.answer_callback_query(call.id, call.data, show_alert=True)
|
||||
|
||||
|
||||
@bot.callback_query_handler(func=None, text=TextFilter(equals='example', ignore_case=True))
|
||||
async def callback_query_handler_ignore_case(call: types.CallbackQuery):
|
||||
await bot.answer_callback_query(call.id, call.data + " ignore case", show_alert=True)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='/poll'))
|
||||
async def send_poll(message: types.Message):
|
||||
await bot.send_poll(message.chat.id, question='When do you prefer to work?', options=['Morning', 'Night'])
|
||||
await bot.send_poll(message.chat.id, question='WHEN DO you pRefeR to worK?', options=['Morning', 'Night'])
|
||||
|
||||
|
||||
@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?'))
|
||||
async def poll_question_handler(poll: types.Poll):
|
||||
print(poll.question)
|
||||
|
||||
|
||||
@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?', ignore_case=True))
|
||||
async def poll_question_handler_ignore_case(poll: types.Poll):
|
||||
print(poll.question + ' ignore case')
|
||||
|
||||
|
||||
# either hi or contains one of (привет, salom)
|
||||
@bot.message_handler(text=TextFilter(equals="hi", contains=('привет', 'salom'), ignore_case=True))
|
||||
async def multiple_patterns_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
# starts with one of (mi, mea) for ex. minor, milk, meal, meat
|
||||
@bot.message_handler(text=TextFilter(starts_with=['mi', 'mea'], ignore_case=True))
|
||||
async def multiple_starts_with_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
# ends with one of (es, on) for ex. Jones, Davies, Johnson, Wilson
|
||||
@bot.message_handler(text=TextFilter(ends_with=['es', 'on'], ignore_case=True))
|
||||
async def multiple_ends_with_handler(message: types.Message):
|
||||
await bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
# !ban /ban .ban !бан /бан .бан
|
||||
@bot.message_handler(is_reply=True,
|
||||
text=TextFilter(starts_with=('!', '/', '.'), ends_with=['ban', 'бан'], ignore_case=True))
|
||||
async def ban_command_handler(message: types.Message):
|
||||
if len(message.text) == 4 and message.chat.type != 'private':
|
||||
try:
|
||||
await bot.ban_chat_member(message.chat.id, message.reply_to_message.from_user.id)
|
||||
await bot.reply_to(message.reply_to_message, 'Banned.')
|
||||
except Exception as err:
|
||||
print(err.args)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
bot.add_custom_filter(TextMatchFilter())
|
||||
bot.add_custom_filter(IsReplyFilter())
|
||||
asyncio.run(bot.polling())
|
125
examples/custom_filters/advanced_text_filter.py
Normal file
125
examples/custom_filters/advanced_text_filter.py
Normal file
@ -0,0 +1,125 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This Example will show you usage of TextFilter
|
||||
In this example you will see how to use TextFilter
|
||||
with (message_handler, callback_query_handler, poll_handler)
|
||||
"""
|
||||
|
||||
from telebot import TeleBot, types
|
||||
from telebot.custom_filters import TextFilter, TextMatchFilter, IsReplyFilter
|
||||
|
||||
bot = TeleBot("")
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='hello'))
|
||||
def hello_handler(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='hello', ignore_case=True))
|
||||
def hello_handler_ignore_case(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(contains=['good', 'bad']))
|
||||
def contains_handler(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(contains=['good', 'bad'], ignore_case=True))
|
||||
def contains_handler_ignore_case(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(starts_with='st')) # stArk, steve, stONE
|
||||
def starts_with_handler(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(starts_with='st', ignore_case=True)) # STark, sTeve, stONE
|
||||
def starts_with_handler_ignore_case(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(ends_with='ay')) # wednesday, SUNday, WeekDay
|
||||
def ends_with_handler(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(ends_with='ay', ignore_case=True)) # wednesdAY, sundAy, WeekdaY
|
||||
def ends_with_handler_ignore_case(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text + ' ignore case')
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='/callback'))
|
||||
def send_callback(message: types.Message):
|
||||
keyboard = types.InlineKeyboardMarkup(
|
||||
keyboard=[
|
||||
[types.InlineKeyboardButton(text='callback data', callback_data='example')],
|
||||
[types.InlineKeyboardButton(text='ignore case callback data', callback_data='ExAmPLe')]
|
||||
]
|
||||
)
|
||||
bot.send_message(message.chat.id, message.text, reply_markup=keyboard)
|
||||
|
||||
|
||||
@bot.callback_query_handler(func=None, text=TextFilter(equals='example'))
|
||||
def callback_query_handler(call: types.CallbackQuery):
|
||||
bot.answer_callback_query(call.id, call.data, show_alert=True)
|
||||
|
||||
|
||||
@bot.callback_query_handler(func=None, text=TextFilter(equals='example', ignore_case=True))
|
||||
def callback_query_handler_ignore_case(call: types.CallbackQuery):
|
||||
bot.answer_callback_query(call.id, call.data + " ignore case", show_alert=True)
|
||||
|
||||
|
||||
@bot.message_handler(text=TextFilter(equals='/poll'))
|
||||
def send_poll(message: types.Message):
|
||||
bot.send_poll(message.chat.id, question='When do you prefer to work?', options=['Morning', 'Night'])
|
||||
bot.send_poll(message.chat.id, question='WHEN DO you pRefeR to worK?', options=['Morning', 'Night'])
|
||||
|
||||
|
||||
@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?'))
|
||||
def poll_question_handler(poll: types.Poll):
|
||||
print(poll.question)
|
||||
|
||||
|
||||
@bot.poll_handler(func=None, text=TextFilter(equals='When do you prefer to work?', ignore_case=True))
|
||||
def poll_question_handler_ignore_case(poll: types.Poll):
|
||||
print(poll.question + ' ignore case')
|
||||
|
||||
|
||||
# either hi or contains one of (привет, salom)
|
||||
@bot.message_handler(text=TextFilter(equals="hi", contains=('привет', 'salom'), ignore_case=True))
|
||||
def multiple_patterns_handler(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
# starts with one of (mi, mea) for ex. minor, milk, meal, meat
|
||||
@bot.message_handler(text=TextFilter(starts_with=['mi', 'mea'], ignore_case=True))
|
||||
def multiple_starts_with_handler(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
# ends with one of (es, on) for ex. Jones, Davies, Johnson, Wilson
|
||||
@bot.message_handler(text=TextFilter(ends_with=['es', 'on'], ignore_case=True))
|
||||
def multiple_ends_with_handler(message: types.Message):
|
||||
bot.send_message(message.chat.id, message.text)
|
||||
|
||||
|
||||
# !ban /ban .ban !бан /бан .бан
|
||||
@bot.message_handler(is_reply=True,
|
||||
text=TextFilter(starts_with=('!', '/', '.'), ends_with=['ban', 'бан'], ignore_case=True))
|
||||
def ban_command_handler(message: types.Message):
|
||||
if len(message.text) == 4 and message.chat.type != 'private':
|
||||
try:
|
||||
bot.ban_chat_member(message.chat.id, message.reply_to_message.from_user.id)
|
||||
bot.reply_to(message.reply_to_message, 'Banned.')
|
||||
except Exception as err:
|
||||
print(err.args)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
bot.add_custom_filter(TextMatchFilter())
|
||||
bot.add_custom_filter(IsReplyFilter())
|
||||
bot.infinity_polling()
|
@ -1,4 +1,8 @@
|
||||
from abc import ABC
|
||||
from typing import Optional, Union
|
||||
|
||||
from telebot import types
|
||||
|
||||
|
||||
class SimpleCustomFilter(ABC):
|
||||
"""
|
||||
@ -30,6 +34,101 @@ class AdvancedCustomFilter(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class TextFilter:
|
||||
"""
|
||||
Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
|
||||
|
||||
example of usage is in examples/custom_filters/advanced_text_filter.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
equals: Optional[str] = None,
|
||||
contains: Optional[Union[list, tuple]] = None,
|
||||
starts_with: Optional[Union[str, list, tuple]] = None,
|
||||
ends_with: Optional[Union[str, list, tuple]] = None,
|
||||
ignore_case: bool = False):
|
||||
|
||||
"""
|
||||
:param equals: string, True if object's text is equal to passed string
|
||||
:param contains: list[str] or tuple[str], True if any string element of iterable is in text
|
||||
:param starts_with: string, True if object's text starts with passed string
|
||||
:param ends_with: string, True if object's text starts with passed string
|
||||
:param ignore_case: bool (default False), case insensitive
|
||||
"""
|
||||
|
||||
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
||||
if to_check == 0:
|
||||
raise ValueError('None of the check modes was specified')
|
||||
|
||||
self.equals = equals
|
||||
self.contains = self._check_iterable(contains, filter_name='contains')
|
||||
self.starts_with = self._check_iterable(starts_with, filter_name='starts_with')
|
||||
self.ends_with = self._check_iterable(ends_with, filter_name='ends_with')
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
def _check_iterable(self, iterable, filter_name):
|
||||
if not iterable:
|
||||
pass
|
||||
elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple):
|
||||
raise ValueError(f"Incorrect value of {filter_name!r}")
|
||||
elif isinstance(iterable, str):
|
||||
iterable = [iterable]
|
||||
elif isinstance(iterable, list) or isinstance(iterable, tuple):
|
||||
iterable = [i for i in iterable if isinstance(i, str)]
|
||||
return iterable
|
||||
|
||||
async def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
||||
|
||||
if isinstance(obj, types.Poll):
|
||||
text = obj.question
|
||||
elif isinstance(obj, types.Message):
|
||||
text = obj.text or obj.caption
|
||||
elif isinstance(obj, types.CallbackQuery):
|
||||
text = obj.data
|
||||
elif isinstance(obj, types.InlineQuery):
|
||||
text = obj.query
|
||||
else:
|
||||
return False
|
||||
|
||||
if self.ignore_case:
|
||||
text = text.lower()
|
||||
|
||||
if self.equals:
|
||||
self.equals = self.equals.lower()
|
||||
elif self.contains:
|
||||
self.contains = tuple(map(str.lower, self.contains))
|
||||
elif self.starts_with:
|
||||
self.starts_with = tuple(map(str.lower, self.starts_with))
|
||||
elif self.ends_with:
|
||||
self.ends_with = tuple(map(str.lower, self.ends_with))
|
||||
|
||||
if self.equals:
|
||||
result = self.equals == text
|
||||
if result:
|
||||
return True
|
||||
elif not result and not any((self.contains, self.starts_with, self.ends_with)):
|
||||
return False
|
||||
|
||||
if self.contains:
|
||||
result = any([i in text for i in self.contains])
|
||||
if result:
|
||||
return True
|
||||
elif not result and not any((self.starts_with, self.ends_with)):
|
||||
return False
|
||||
|
||||
if self.starts_with:
|
||||
result = any([text.startswith(i) for i in self.starts_with])
|
||||
if result:
|
||||
return True
|
||||
elif not result and not self.ends_with:
|
||||
return False
|
||||
|
||||
if self.ends_with:
|
||||
return any([text.endswith(i) for i in self.ends_with])
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TextMatchFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check Text message.
|
||||
@ -42,8 +141,13 @@ class TextMatchFilter(AdvancedCustomFilter):
|
||||
key = 'text'
|
||||
|
||||
async def check(self, message, text):
|
||||
if type(text) is list:return message.text in text
|
||||
else: return text == message.text
|
||||
if isinstance(text, TextFilter):
|
||||
return await text.check(message)
|
||||
elif type(text) is list:
|
||||
return message.text in text
|
||||
else:
|
||||
return text == message.text
|
||||
|
||||
|
||||
class TextContainsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
@ -58,7 +162,15 @@ class TextContainsFilter(AdvancedCustomFilter):
|
||||
key = 'text_contains'
|
||||
|
||||
async def check(self, message, text):
|
||||
return text in message.text
|
||||
if not isinstance(text, str) and not isinstance(text, list) and not isinstance(text, tuple):
|
||||
raise ValueError("Incorrect text_contains value")
|
||||
elif isinstance(text, str):
|
||||
text = [text]
|
||||
elif isinstance(text, list) or isinstance(text, tuple):
|
||||
text = [i for i in text if isinstance(i, str)]
|
||||
|
||||
return any([i in message.text for i in text])
|
||||
|
||||
|
||||
class TextStartsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
@ -70,9 +182,11 @@ class TextStartsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
|
||||
key = 'text_startswith'
|
||||
|
||||
async def check(self, message, text):
|
||||
return message.text.startswith(text)
|
||||
|
||||
|
||||
class ChatFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Check whether chat_id corresponds to given chat_id.
|
||||
@ -82,9 +196,11 @@ class ChatFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
|
||||
key = 'chat_id'
|
||||
|
||||
async def check(self, message, text):
|
||||
return message.chat.id in text
|
||||
|
||||
|
||||
class ForwardFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Check whether message was forwarded from channel or group.
|
||||
@ -99,6 +215,7 @@ class ForwardFilter(SimpleCustomFilter):
|
||||
async def check(self, message):
|
||||
return message.forward_from_chat is not None
|
||||
|
||||
|
||||
class IsReplyFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Check whether message is a reply.
|
||||
@ -114,7 +231,6 @@ class IsReplyFilter(SimpleCustomFilter):
|
||||
return message.reply_to_message is not None
|
||||
|
||||
|
||||
|
||||
class LanguageFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Check users language_code.
|
||||
@ -127,8 +243,11 @@ class LanguageFilter(AdvancedCustomFilter):
|
||||
key = 'language_code'
|
||||
|
||||
async def check(self, message, text):
|
||||
if type(text) is list:return message.from_user.language_code in text
|
||||
else: return message.from_user.language_code == text
|
||||
if type(text) is list:
|
||||
return message.from_user.language_code in text
|
||||
else:
|
||||
return message.from_user.language_code == text
|
||||
|
||||
|
||||
class IsAdminFilter(SimpleCustomFilter):
|
||||
"""
|
||||
@ -147,6 +266,7 @@ class IsAdminFilter(SimpleCustomFilter):
|
||||
result = await self._bot.get_chat_member(message.chat.id, message.from_user.id)
|
||||
return result.status in ['creator', 'administrator']
|
||||
|
||||
|
||||
class StateFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check state.
|
||||
@ -154,8 +274,10 @@ class StateFilter(AdvancedCustomFilter):
|
||||
Example:
|
||||
@bot.message_handler(state=1)
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
key = 'state'
|
||||
|
||||
async def check(self, message, text):
|
||||
@ -176,12 +298,13 @@ class StateFilter(AdvancedCustomFilter):
|
||||
|
||||
|
||||
else:
|
||||
user_state = await self.bot.current_states.get_state(message.chat.id,message.from_user.id)
|
||||
user_state = await self.bot.current_states.get_state(message.chat.id, message.from_user.id)
|
||||
if user_state == text:
|
||||
return True
|
||||
elif type(text) is list and user_state in text:
|
||||
return True
|
||||
|
||||
|
||||
class IsDigitFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Filter to check whether the string is made up of only digits.
|
||||
|
@ -1,4 +1,9 @@
|
||||
from abc import ABC
|
||||
from typing import Optional, Union
|
||||
|
||||
from telebot import types
|
||||
|
||||
|
||||
class SimpleCustomFilter(ABC):
|
||||
"""
|
||||
Simple Custom Filter base class.
|
||||
@ -29,6 +34,100 @@ class AdvancedCustomFilter(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class TextFilter:
|
||||
"""
|
||||
Advanced text filter to check (types.Message, types.CallbackQuery, types.InlineQuery, types.Poll)
|
||||
|
||||
example of usage is in examples/custom_filters/advanced_text_filter.py
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
equals: Optional[str] = None,
|
||||
contains: Optional[Union[list, tuple]] = None,
|
||||
starts_with: Optional[Union[str, list, tuple]] = None,
|
||||
ends_with: Optional[Union[str, list, tuple]] = None,
|
||||
ignore_case: bool = False):
|
||||
|
||||
"""
|
||||
:param equals: string, True if object's text is equal to passed string
|
||||
:param contains: list[str] or tuple[str], True if any string element of iterable is in text
|
||||
:param starts_with: string, True if object's text starts with passed string
|
||||
:param ends_with: string, True if object's text starts with passed string
|
||||
:param ignore_case: bool (default False), case insensitive
|
||||
"""
|
||||
|
||||
to_check = sum((pattern is not None for pattern in (equals, contains, starts_with, ends_with)))
|
||||
if to_check == 0:
|
||||
raise ValueError('None of the check modes was specified')
|
||||
|
||||
self.equals = equals
|
||||
self.contains = self._check_iterable(contains, filter_name='contains')
|
||||
self.starts_with = self._check_iterable(starts_with, filter_name='starts_with')
|
||||
self.ends_with = self._check_iterable(ends_with, filter_name='ends_with')
|
||||
self.ignore_case = ignore_case
|
||||
|
||||
def _check_iterable(self, iterable, filter_name: str):
|
||||
if not iterable:
|
||||
pass
|
||||
elif not isinstance(iterable, str) and not isinstance(iterable, list) and not isinstance(iterable, tuple):
|
||||
raise ValueError(f"Incorrect value of {filter_name!r}")
|
||||
elif isinstance(iterable, str):
|
||||
iterable = [iterable]
|
||||
elif isinstance(iterable, list) or isinstance(iterable, tuple):
|
||||
iterable = [i for i in iterable if isinstance(i, str)]
|
||||
return iterable
|
||||
|
||||
def check(self, obj: Union[types.Message, types.CallbackQuery, types.InlineQuery, types.Poll]):
|
||||
|
||||
if isinstance(obj, types.Poll):
|
||||
text = obj.question
|
||||
elif isinstance(obj, types.Message):
|
||||
text = obj.text or obj.caption
|
||||
elif isinstance(obj, types.CallbackQuery):
|
||||
text = obj.data
|
||||
elif isinstance(obj, types.InlineQuery):
|
||||
text = obj.query
|
||||
else:
|
||||
return False
|
||||
|
||||
if self.ignore_case:
|
||||
text = text.lower()
|
||||
|
||||
if self.equals:
|
||||
self.equals = self.equals.lower()
|
||||
elif self.contains:
|
||||
self.contains = tuple(map(str.lower, self.contains))
|
||||
elif self.starts_with:
|
||||
self.starts_with = tuple(map(str.lower, self.starts_with))
|
||||
elif self.ends_with:
|
||||
self.ends_with = tuple(map(str.lower, self.ends_with))
|
||||
|
||||
if self.equals:
|
||||
result = self.equals == text
|
||||
if result:
|
||||
return True
|
||||
elif not result and not any((self.contains, self.starts_with, self.ends_with)):
|
||||
return False
|
||||
|
||||
if self.contains:
|
||||
result = any([i in text for i in self.contains])
|
||||
if result:
|
||||
return True
|
||||
elif not result and not any((self.starts_with, self.ends_with)):
|
||||
return False
|
||||
|
||||
if self.starts_with:
|
||||
result = any([text.startswith(i) for i in self.starts_with])
|
||||
if result:
|
||||
return True
|
||||
elif not result and not self.ends_with:
|
||||
return False
|
||||
|
||||
if self.ends_with:
|
||||
return any([text.endswith(i) for i in self.ends_with])
|
||||
|
||||
return False
|
||||
|
||||
class TextMatchFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check Text message.
|
||||
@ -41,8 +140,13 @@ class TextMatchFilter(AdvancedCustomFilter):
|
||||
key = 'text'
|
||||
|
||||
def check(self, message, text):
|
||||
if type(text) is list:return message.text in text
|
||||
else: return text == message.text
|
||||
if isinstance(text, TextFilter):
|
||||
return text.check(message)
|
||||
elif type(text) is list:
|
||||
return message.text in text
|
||||
else:
|
||||
return text == message.text
|
||||
|
||||
|
||||
class TextContainsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
@ -57,7 +161,15 @@ class TextContainsFilter(AdvancedCustomFilter):
|
||||
key = 'text_contains'
|
||||
|
||||
def check(self, message, text):
|
||||
return text in message.text
|
||||
if not isinstance(text, str) and not isinstance(text, list) and not isinstance(text, tuple):
|
||||
raise ValueError("Incorrect text_contains value")
|
||||
elif isinstance(text, str):
|
||||
text = [text]
|
||||
elif isinstance(text, list) or isinstance(text, tuple):
|
||||
text = [i for i in text if isinstance(i, str)]
|
||||
|
||||
return any([i in message.text for i in text])
|
||||
|
||||
|
||||
class TextStartsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
@ -69,9 +181,11 @@ class TextStartsFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
|
||||
key = 'text_startswith'
|
||||
|
||||
def check(self, message, text):
|
||||
return message.text.startswith(text)
|
||||
|
||||
|
||||
class ChatFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Check whether chat_id corresponds to given chat_id.
|
||||
@ -81,9 +195,11 @@ class ChatFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
|
||||
key = 'chat_id'
|
||||
|
||||
def check(self, message, text):
|
||||
return message.chat.id in text
|
||||
|
||||
|
||||
class ForwardFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Check whether message was forwarded from channel or group.
|
||||
@ -98,6 +214,7 @@ class ForwardFilter(SimpleCustomFilter):
|
||||
def check(self, message):
|
||||
return message.forward_from_chat is not None
|
||||
|
||||
|
||||
class IsReplyFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Check whether message is a reply.
|
||||
@ -113,7 +230,6 @@ class IsReplyFilter(SimpleCustomFilter):
|
||||
return message.reply_to_message is not None
|
||||
|
||||
|
||||
|
||||
class LanguageFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Check users language_code.
|
||||
@ -126,8 +242,11 @@ class LanguageFilter(AdvancedCustomFilter):
|
||||
key = 'language_code'
|
||||
|
||||
def check(self, message, text):
|
||||
if type(text) is list:return message.from_user.language_code in text
|
||||
else: return message.from_user.language_code == text
|
||||
if type(text) is list:
|
||||
return message.from_user.language_code in text
|
||||
else:
|
||||
return message.from_user.language_code == text
|
||||
|
||||
|
||||
class IsAdminFilter(SimpleCustomFilter):
|
||||
"""
|
||||
@ -145,6 +264,7 @@ class IsAdminFilter(SimpleCustomFilter):
|
||||
def check(self, message):
|
||||
return self._bot.get_chat_member(message.chat.id, message.from_user.id).status in ['creator', 'administrator']
|
||||
|
||||
|
||||
class StateFilter(AdvancedCustomFilter):
|
||||
"""
|
||||
Filter to check state.
|
||||
@ -152,8 +272,10 @@ class StateFilter(AdvancedCustomFilter):
|
||||
Example:
|
||||
@bot.message_handler(state=1)
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
key = 'state'
|
||||
|
||||
def check(self, message, text):
|
||||
@ -173,11 +295,13 @@ class StateFilter(AdvancedCustomFilter):
|
||||
|
||||
|
||||
else:
|
||||
user_state = self.bot.current_states.get_state(message.chat.id,message.from_user.id)
|
||||
user_state = self.bot.current_states.get_state(message.chat.id, message.from_user.id)
|
||||
if user_state == text:
|
||||
return True
|
||||
elif type(text) is list and user_state in text:
|
||||
return True
|
||||
|
||||
|
||||
class IsDigitFilter(SimpleCustomFilter):
|
||||
"""
|
||||
Filter to check whether the string is made up of only digits.
|
||||
|
Loading…
Reference in New Issue
Block a user