1
0
mirror of https://github.com/eternnoir/pyTelegramBotAPI.git synced 2023-08-10 21:12:57 +03:00

Better error handling.

Errors now are re-raised in the Thread polling() was called from.
If none_stop is *not* set, ApiExceptions will cause the calling Thread to halt.
This commit is contained in:
pieter 2015-10-01 22:03:54 +02:00
parent 60ca1751ca
commit d14e9051d4
2 changed files with 134 additions and 58 deletions

View File

@ -10,7 +10,9 @@ import six
import logging import logging
logger = logging.getLogger('TeleBot') logger = logging.getLogger('TeleBot')
formatter = logging.Formatter('%(asctime)s (%(filename)s:%(lineno)d) %(levelname)s - %(name)s: "%(message)s"') formatter = logging.Formatter(
'%(asctime)s (%(filename)s:%(lineno)d %(threadName)s) %(levelname)s - %(name)s: "%(message)s"'
)
console_output_handler = logging.StreamHandler(sys.stderr) console_output_handler = logging.StreamHandler(sys.stderr)
console_output_handler.setFormatter(formatter) console_output_handler.setFormatter(formatter)
@ -42,20 +44,17 @@ class TeleBot:
getUpdates getUpdates
""" """
def __init__(self, token, create_threads=True, num_threads=4): def __init__(self, token):
""" """
:param token: bot API token :param token: bot API token
:param create_threads: Create thread for message handler
:param num_threads: Number of worker in thread pool.
:return: Telebot object. :return: Telebot object.
""" """
self.token = token self.token = token
self.update_listener = [] self.update_listener = []
self.polling_thread = None
self.__stop_polling = threading.Event() self.__stop_polling = threading.Event()
self.last_update_id = 0 self.last_update_id = 0
self.num_threads = num_threads self.exc_info = None
self.__create_threads = create_threads
self.message_subscribers_messages = [] self.message_subscribers_messages = []
self.message_subscribers_callbacks = [] self.message_subscribers_callbacks = []
@ -65,8 +64,7 @@ class TeleBot:
self.message_subscribers_next_step = {} self.message_subscribers_next_step = {}
self.message_handlers = [] self.message_handlers = []
if self.__create_threads: self.worker_pool = util.ThreadPool()
self.worker_pool = util.ThreadPool(num_threads)
def set_webhook(self, url=None, certificate=None): def set_webhook(self, url=None, certificate=None):
return apihelper.set_webhook(self.token, url, certificate) return apihelper.set_webhook(self.token, url, certificate)
@ -88,12 +86,13 @@ class TeleBot:
ret.append(types.Update.de_json(ju)) ret.append(types.Update.de_json(ju))
return ret return ret
def get_update(self, timeout=20): def __retrieve_updates(self, timeout=20):
""" """
Retrieves any updates from the Telegram API. Retrieves any updates from the Telegram API.
Registered listeners and applicable message handlers will be notified when a new message arrives. Registered listeners and applicable message handlers will be notified when a new message arrives.
:raises ApiException when a call has failed. :raises ApiException when a call has failed.
""" """
raise apihelper.ApiException("Test2", None, None)
updates = self.get_updates(offset=(self.last_update_id + 1), timeout=timeout) updates = self.get_updates(offset=(self.last_update_id + 1), timeout=timeout)
new_messages = [] new_messages = []
for update in updates: for update in updates:
@ -112,57 +111,52 @@ class TeleBot:
def __notify_update(self, new_messages): def __notify_update(self, new_messages):
for listener in self.update_listener: for listener in self.update_listener:
if self.__create_threads:
self.worker_pool.put(listener, new_messages) self.worker_pool.put(listener, new_messages)
else:
listener(new_messages)
def polling(self, none_stop=False, interval=0, block=True, timeout=20): def polling(self, none_stop=False, interval=0, timeout=20):
""" """
This function creates a new Thread that calls an internal __polling function. This function creates a new Thread that calls an internal __retrieve_updates function.
This allows the bot to retrieve Updates automagically and notify listeners and message handlers accordingly. This allows the bot to retrieve Updates automagically and notify listeners and message handlers accordingly.
Do not call this function more than once! Warning: Do not call this function more than once!
Always get updates. Always get updates.
:param none_stop: Do not stop polling when Exception occur. :param none_stop: Do not stop polling when an ApiException occurs.
:param timeout: Timeout in seconds for long polling. :param timeout: Timeout in seconds for long polling.
:return: :return:
""" """
self.__stop_polling.set()
if self.polling_thread:
self.polling_thread.join() # wait thread stop.
self.__stop_polling.clear()
self.polling_thread = threading.Thread(target=self.__polling, args=([none_stop, interval, timeout]))
self.polling_thread.daemon = True
self.polling_thread.start()
if block:
while self.polling_thread.is_alive:
try:
time.sleep(.1)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt. Stopping.")
self.stop_polling()
self.polling_thread.join()
break
def __polling(self, none_stop, interval, timeout):
logger.info('Started polling.') logger.info('Started polling.')
error_interval = .25 error_interval = .25
polling_thread = util.WorkerThread(name="PollingThread")
or_event = util.OrEvent(
polling_thread.done_event,
polling_thread.exception_event,
self.worker_pool.exception_event
)
while not self.__stop_polling.wait(interval): while not self.__stop_polling.wait(interval):
or_event.clear()
try: try:
self.get_update(timeout) polling_thread.put(self.__retrieve_updates, timeout)
or_event.wait()
polling_thread.raise_exceptions()
self.worker_pool.raise_exceptions()
error_interval = .25 error_interval = .25
except apihelper.ApiException as e: except apihelper.ApiException as e:
logger.error(e)
if not none_stop: if not none_stop:
self.__stop_polling.set() self.__stop_polling.set()
logger.info("Exception occurred. Stopping.") logger.info("Exception occurred. Stopping.")
else: else:
polling_thread.clear_exceptions()
self.worker_pool.clear_exceptions()
logger.info("Waiting for {0} seconds until retry".format(error_interval))
time.sleep(error_interval) time.sleep(error_interval)
error_interval *= 2 error_interval *= 2
logger.error(e)
logger.info('Stopped polling.') logger.info('Stopped polling.')
@ -459,10 +453,7 @@ class TeleBot:
for message in new_messages: for message in new_messages:
for message_handler in self.message_handlers: for message_handler in self.message_handlers:
if self._test_message_handler(message_handler, message): if self._test_message_handler(message_handler, message):
if self.__create_threads:
self.worker_pool.put(message_handler['function'], message) self.worker_pool.put(message_handler['function'], message)
else:
message_handler['function'](message)
break break

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import threading import threading
import sys
from six import string_types from six import string_types
# Python3 queue support. # Python3 queue support.
@ -9,45 +10,97 @@ try:
except ImportError: except ImportError:
import queue as Queue import queue as Queue
from apihelper import ApiException
from telebot import logger from telebot import logger
class ThreadPool:
class WorkerThread(threading.Thread): class WorkerThread(threading.Thread):
count = 0 count = 0
def __init__(self, queue): def __init__(self, exception_callback=None, queue=Queue.Queue(), name=None):
threading.Thread.__init__(self, name="WorkerThread{0}".format(self.__class__.count + 1)) if not name:
name = "WorkerThread{0}".format(self.__class__.count + 1)
self.__class__.count += 1 self.__class__.count += 1
threading.Thread.__init__(self, name=name)
self.queue = queue self.queue = queue
self.daemon = True self.daemon = True
self.received_task_event = threading.Event()
self.done_event = threading.Event()
self.exception_event = threading.Event()
self.continue_event = threading.Event()
self.exception_callback = exception_callback
self.exc_info = None
self._running = True self._running = True
self.start() self.start()
def run(self): def run(self):
while self._running: while self._running:
self.continue_event.clear()
self.received_task_event.clear()
self.done_event.clear()
self.exception_event.clear()
try: try:
task, args, kwargs = self.queue.get(block=True, timeout=.01) task, args, kwargs = self.queue.get(block=True, timeout=.01)
logger.debug("Received task")
self.received_task_event.set()
task(*args, **kwargs) task(*args, **kwargs)
logger.debug("Task complete")
self.done_event.set()
except Queue.Empty: except Queue.Empty:
pass pass
except ApiException as e: except:
logger.exception(e) logger.debug("Exception occurred")
self.exc_info = sys.exc_info()
self.exception_event.set()
if self.exception_callback:
self.exception_callback(self, self.exc_info)
self.continue_event.wait()
def put(self, task, *args, **kwargs):
self.queue.put((task, args, kwargs))
def raise_exceptions(self):
if self.exception_event.is_set():
raise self.exc_info[0], self.exc_info[1], self.exc_info[2]
def clear_exceptions(self):
self.exception_event.clear()
self.continue_event.set()
def stop(self): def stop(self):
self._running = False self._running = False
def __init__(self, num_threads=4):
self.tasks = Queue.Queue()
self.workers = [self.WorkerThread(self.tasks) for _ in range(num_threads)]
class ThreadPool:
def __init__(self, num_threads=2):
self.tasks = Queue.Queue()
self.workers = [WorkerThread(self.on_exception, self.tasks) for _ in range(num_threads)]
self.num_threads = num_threads self.num_threads = num_threads
self.exception_event = threading.Event()
self.exc_info = None
def put(self, func, *args, **kwargs): def put(self, func, *args, **kwargs):
self.tasks.put((func, args, kwargs)) self.tasks.put((func, args, kwargs))
def on_exception(self, worker_thread, exc_info):
self.exc_info = exc_info
self.exception_event.set()
worker_thread.continue_event.set()
def raise_exceptions(self):
if self.exception_event.is_set():
raise self.exc_info[0], self.exc_info[1], self.exc_info[2]
def clear_exceptions(self):
self.exception_event.clear()
def close(self): def close(self):
for worker in self.workers: for worker in self.workers:
worker.stop() worker.stop()
@ -68,15 +121,15 @@ class AsyncTask:
def _run(self): def _run(self):
try: try:
self.result = self.target(*self.args, **self.kwargs) self.result = self.target(*self.args, **self.kwargs)
except Exception as e: except:
self.result = e self.result = sys.exc_info()
self.done = True self.done = True
def wait(self): def wait(self):
if not self.done: if not self.done:
self.thread.join() self.thread.join()
if isinstance(self.result, Exception): if isinstance(self.result, BaseException):
raise self.result raise self.result[0], self.result[1], self.result[2]
else: else:
return self.result return self.result
@ -130,3 +183,35 @@ def split_string(text, chars_per_string):
:return: The splitted text as a list of strings. :return: The splitted text as a list of strings.
""" """
return [text[i:i + chars_per_string] for i in range(0, len(text), chars_per_string)] return [text[i:i + chars_per_string] for i in range(0, len(text), chars_per_string)]
# CREDITS TO http://stackoverflow.com/questions/12317940#answer-12320352
def or_set(self):
self._set()
self.changed()
def or_clear(self):
self._clear()
self.changed()
def orify(e, changed_callback):
e._set = e.set
e._clear = e.clear
e.changed = changed_callback
e.set = lambda: or_set(e)
e.clear = lambda: or_clear(e)
def OrEvent(*events):
or_event = threading.Event()
def changed():
bools = [e.is_set() for e in events]
if any(bools):
or_event.set()
else:
or_event.clear()
for e in events:
orify(e, changed)
changed()
return or_event