diff --git a/telebot/__init__.py b/telebot/__init__.py index fefd164..9e23873 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -10,7 +10,9 @@ import six import logging logger = logging.getLogger('TeleBot') -formatter = logging.Formatter('%(asctime)s (%(filename)s:%(lineno)d) %(levelname)s - %(name)s: "%(message)s"') +formatter = logging.Formatter( + '%(asctime)s (%(filename)s:%(lineno)d %(threadName)s) %(levelname)s - %(name)s: "%(message)s"' +) console_output_handler = logging.StreamHandler(sys.stderr) console_output_handler.setFormatter(formatter) @@ -42,20 +44,17 @@ class TeleBot: getUpdates """ - def __init__(self, token, create_threads=True, num_threads=4): + def __init__(self, token): """ :param token: bot API token - :param create_threads: Create thread for message handler - :param num_threads: Number of worker in thread pool. :return: Telebot object. """ self.token = token self.update_listener = [] - self.polling_thread = None + self.__stop_polling = threading.Event() self.last_update_id = 0 - self.num_threads = num_threads - self.__create_threads = create_threads + self.exc_info = None self.message_subscribers_messages = [] self.message_subscribers_callbacks = [] @@ -65,8 +64,7 @@ class TeleBot: self.message_subscribers_next_step = {} self.message_handlers = [] - if self.__create_threads: - self.worker_pool = util.ThreadPool(num_threads) + self.worker_pool = util.ThreadPool() def set_webhook(self, url=None, certificate=None): return apihelper.set_webhook(self.token, url, certificate) @@ -88,12 +86,13 @@ class TeleBot: ret.append(types.Update.de_json(ju)) return ret - def get_update(self, timeout=20): + def __retrieve_updates(self, timeout=20): """ Retrieves any updates from the Telegram API. Registered listeners and applicable message handlers will be notified when a new message arrives. :raises ApiException when a call has failed. """ + raise apihelper.ApiException("Test2", None, None) updates = self.get_updates(offset=(self.last_update_id + 1), timeout=timeout) new_messages = [] for update in updates: @@ -112,57 +111,52 @@ class TeleBot: def __notify_update(self, new_messages): for listener in self.update_listener: - if self.__create_threads: - self.worker_pool.put(listener, new_messages) - else: - listener(new_messages) + self.worker_pool.put(listener, new_messages) - def polling(self, none_stop=False, interval=0, block=True, timeout=20): + def polling(self, none_stop=False, interval=0, timeout=20): """ - This function creates a new Thread that calls an internal __polling function. + This function creates a new Thread that calls an internal __retrieve_updates function. This allows the bot to retrieve Updates automagically and notify listeners and message handlers accordingly. - Do not call this function more than once! + Warning: Do not call this function more than once! Always get updates. - :param none_stop: Do not stop polling when Exception occur. + :param none_stop: Do not stop polling when an ApiException occurs. :param timeout: Timeout in seconds for long polling. :return: """ - self.__stop_polling.set() - if self.polling_thread: - self.polling_thread.join() # wait thread stop. - self.__stop_polling.clear() - self.polling_thread = threading.Thread(target=self.__polling, args=([none_stop, interval, timeout])) - self.polling_thread.daemon = True - self.polling_thread.start() - - if block: - while self.polling_thread.is_alive: - try: - time.sleep(.1) - except KeyboardInterrupt: - logger.info("Received KeyboardInterrupt. Stopping.") - self.stop_polling() - self.polling_thread.join() - break - - def __polling(self, none_stop, interval, timeout): logger.info('Started polling.') error_interval = .25 + + polling_thread = util.WorkerThread(name="PollingThread") + or_event = util.OrEvent( + polling_thread.done_event, + polling_thread.exception_event, + self.worker_pool.exception_event + ) + while not self.__stop_polling.wait(interval): + or_event.clear() try: - self.get_update(timeout) + polling_thread.put(self.__retrieve_updates, timeout) + or_event.wait() + + polling_thread.raise_exceptions() + self.worker_pool.raise_exceptions() + error_interval = .25 except apihelper.ApiException as e: + logger.error(e) if not none_stop: self.__stop_polling.set() logger.info("Exception occurred. Stopping.") else: + polling_thread.clear_exceptions() + self.worker_pool.clear_exceptions() + logger.info("Waiting for {0} seconds until retry".format(error_interval)) time.sleep(error_interval) error_interval *= 2 - logger.error(e) logger.info('Stopped polling.') @@ -459,10 +453,7 @@ class TeleBot: for message in new_messages: for message_handler in self.message_handlers: if self._test_message_handler(message_handler, message): - if self.__create_threads: - self.worker_pool.put(message_handler['function'], message) - else: - message_handler['function'](message) + self.worker_pool.put(message_handler['function'], message) break diff --git a/telebot/util.py b/telebot/util.py index 168ad58..5b72bf1 100644 --- a/telebot/util.py +++ b/telebot/util.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import threading +import sys from six import string_types # Python3 queue support. @@ -9,45 +10,97 @@ try: except ImportError: import queue as Queue -from apihelper import ApiException from telebot import logger -class ThreadPool: - class WorkerThread(threading.Thread): +class WorkerThread(threading.Thread): count = 0 - def __init__(self, queue): - threading.Thread.__init__(self, name="WorkerThread{0}".format(self.__class__.count + 1)) - self.__class__.count += 1 + def __init__(self, exception_callback=None, queue=Queue.Queue(), name=None): + if not name: + name = "WorkerThread{0}".format(self.__class__.count + 1) + self.__class__.count += 1 + + threading.Thread.__init__(self, name=name) self.queue = queue self.daemon = True + self.received_task_event = threading.Event() + self.done_event = threading.Event() + self.exception_event = threading.Event() + self.continue_event = threading.Event() + + self.exception_callback = exception_callback + self.exc_info = None self._running = True self.start() def run(self): while self._running: + self.continue_event.clear() + self.received_task_event.clear() + self.done_event.clear() + self.exception_event.clear() + try: task, args, kwargs = self.queue.get(block=True, timeout=.01) + logger.debug("Received task") + self.received_task_event.set() + task(*args, **kwargs) + logger.debug("Task complete") + self.done_event.set() except Queue.Empty: pass - except ApiException as e: - logger.exception(e) + except: + logger.debug("Exception occurred") + self.exc_info = sys.exc_info() + self.exception_event.set() + + if self.exception_callback: + self.exception_callback(self, self.exc_info) + self.continue_event.wait() + + def put(self, task, *args, **kwargs): + self.queue.put((task, args, kwargs)) + + def raise_exceptions(self): + if self.exception_event.is_set(): + raise self.exc_info[0], self.exc_info[1], self.exc_info[2] + + def clear_exceptions(self): + self.exception_event.clear() + self.continue_event.set() def stop(self): self._running = False - def __init__(self, num_threads=4): - self.tasks = Queue.Queue() - self.workers = [self.WorkerThread(self.tasks) for _ in range(num_threads)] +class ThreadPool: + + def __init__(self, num_threads=2): + self.tasks = Queue.Queue() + self.workers = [WorkerThread(self.on_exception, self.tasks) for _ in range(num_threads)] self.num_threads = num_threads + self.exception_event = threading.Event() + self.exc_info = None + def put(self, func, *args, **kwargs): self.tasks.put((func, args, kwargs)) + def on_exception(self, worker_thread, exc_info): + self.exc_info = exc_info + self.exception_event.set() + worker_thread.continue_event.set() + + def raise_exceptions(self): + if self.exception_event.is_set(): + raise self.exc_info[0], self.exc_info[1], self.exc_info[2] + + def clear_exceptions(self): + self.exception_event.clear() + def close(self): for worker in self.workers: worker.stop() @@ -68,15 +121,15 @@ class AsyncTask: def _run(self): try: self.result = self.target(*self.args, **self.kwargs) - except Exception as e: - self.result = e + except: + self.result = sys.exc_info() self.done = True def wait(self): if not self.done: self.thread.join() - if isinstance(self.result, Exception): - raise self.result + if isinstance(self.result, BaseException): + raise self.result[0], self.result[1], self.result[2] else: return self.result @@ -130,3 +183,35 @@ def split_string(text, chars_per_string): :return: The splitted text as a list of strings. """ return [text[i:i + chars_per_string] for i in range(0, len(text), chars_per_string)] + +# CREDITS TO http://stackoverflow.com/questions/12317940#answer-12320352 +def or_set(self): + self._set() + self.changed() + + +def or_clear(self): + self._clear() + self.changed() + + +def orify(e, changed_callback): + e._set = e.set + e._clear = e.clear + e.changed = changed_callback + e.set = lambda: or_set(e) + e.clear = lambda: or_clear(e) + + +def OrEvent(*events): + or_event = threading.Event() + def changed(): + bools = [e.is_set() for e in events] + if any(bools): + or_event.set() + else: + or_event.clear() + for e in events: + orify(e, changed) + changed() + return or_event