From a03769adde662bbb20620d3fd2551d5046feeb5d Mon Sep 17 00:00:00 2001 From: pieter Date: Wed, 1 Jul 2015 22:16:13 +0200 Subject: [PATCH] Created _make_request and removed check_result in apihelper.py for efficiency and consistency improvements. Created JsonSerializable (previously Jsonable) and JsonDeserializable. All relevant classes now subclass JsonDeserializable to eliminate unneeded json -> string -> json conversions. --- telebot/__init__.py | 9 +++-- telebot/apihelper.py | 77 +++++++++++++++--------------------------- telebot/types.py | 79 ++++++++++++++++++++++++++++++-------------- 3 files changed, 85 insertions(+), 80 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index 31bff89..6630f31 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -42,13 +42,12 @@ class TeleBot: self.last_update_id = 0 def get_update(self): - result = apihelper.get_updates(self.token, offset=(self.last_update_id + 1)) - updates = result['result'] + updates = apihelper.get_updates(self.token, offset=(self.last_update_id + 1)) new_messages = [] for update in updates: if update['update_id'] > self.last_update_id: self.last_update_id = update['update_id'] - msg = types.Message.de_json(json.dumps(update['message'])) + msg = types.Message.de_json(update['message']) new_messages.append(msg) if len(new_messages) > 0: @@ -93,7 +92,7 @@ class TeleBot: def get_me(self): result = apihelper.get_me(self.token) - return types.User.de_json(json.dumps(result['result'])) + return types.User.de_json(result) def get_user_profile_photos(self, user_id, offset=None, limit=None): """ @@ -105,7 +104,7 @@ class TeleBot: :return: """ result = apihelper.get_user_profile_photos(self.token, user_id, offset, limit) - return types.UserProfilePhotos.de_json(json.dumps(result['result'])) + return types.UserProfilePhotos.de_json(result) def send_message(self, chat_id, text, disable_web_page_preview=None, reply_to_message_id=None, reply_markup=None): """ diff --git a/telebot/apihelper.py b/telebot/apihelper.py index ab59b2a..4b5e0cd 100644 --- a/telebot/apihelper.py +++ b/telebot/apihelper.py @@ -6,12 +6,23 @@ import telebot from telebot import types +def _make_request(token, method_name, method='get', params=None, files=None): + request_url = telebot.API_URL + 'bot' + token + '/' + method_name + result = requests.request(method, request_url, params=params, files=files) + if result.status_code != 200: + raise ApiError(method_name + r' error.', result) + try: + result_json = result.json() + if not result_json['ok']: + raise ApiError(method_name, ' failed, result=' + result_json) + except: + raise ApiError(method_name + r' error.', result) + return result_json['result'] + + def get_me(token): - api_url = telebot.API_URL - method_url = r'getMe' - request_url = api_url + 'bot' + token + '/' + method_url - req = requests.get(request_url) - return check_result(method_url, req) + method_url = 'getMe' + return _make_request(token, method_url) def send_message(token, chat_id, text, disable_web_page_preview=None, reply_to_message_id=None, reply_markup=None): @@ -25,9 +36,7 @@ def send_message(token, chat_id, text, disable_web_page_preview=None, reply_to_m :param reply_markup: :return: """ - api_url = telebot.API_URL method_url = r'sendMessage' - request_url = api_url + 'bot' + token + '/' + method_url payload = {'chat_id': str(chat_id), 'text': text} if disable_web_page_preview: payload['disable_web_page_preview'] = disable_web_page_preview @@ -35,46 +44,34 @@ def send_message(token, chat_id, text, disable_web_page_preview=None, reply_to_m payload['reply_to_message_id'] = reply_to_message_id if reply_markup: payload['reply_markup'] = convert_markup(reply_markup) - req = requests.get(request_url, params=payload) - return check_result(method_url, req) + return _make_request(token, method_url, params=payload) def get_updates(token, offset=None): - api_url = telebot.API_URL method_url = r'getUpdates' if offset is not None: - request_url = api_url + 'bot' + token + '/' + method_url + '?offset=' + str(offset) + return _make_request(token, method_url, params={'offset': offset}) else: - request_url = api_url + 'bot' + token + '/' + method_url - req = requests.get(request_url) - return check_result(method_url, req) + return _make_request(token, method_url) def get_user_profile_photos(token, user_id, offset=None, limit=None): - api_url = telebot.API_URL method_url = r'getUserProfilePhotos' - request_url = api_url + 'bot' + token + '/' + method_url payload = {'user_id': user_id} if offset: payload['offset'] = offset if limit: payload['limit'] = limit - req = requests.get(request_url, params=payload) - return check_result(method_url, req) + return _make_request(token, method_url, params=payload) def forward_message(token, chat_id, from_chat_id, message_id): - api_url = telebot.API_URL method_url = r'forwardMessage' - request_url = api_url + 'bot' + token + '/' + method_url payload = {'chat_id': chat_id, 'from_chat_id': from_chat_id, 'message_id': message_id} - req = requests.get(request_url, params=payload) - return check_result(method_url, req) + return _make_request(token, method_url, params=payload) def send_photo(token, chat_id, photo, caption=None, reply_to_message_id=None, reply_markup=None): - api_url = telebot.API_URL method_url = r'sendPhoto' - request_url = api_url + 'bot' + token + '/' + method_url payload = {'chat_id': chat_id} files = {'photo': photo} if caption: @@ -83,44 +80,34 @@ def send_photo(token, chat_id, photo, caption=None, reply_to_message_id=None, re payload['reply_to_message_id'] = reply_to_message_id if reply_markup: payload['reply_markup'] = convert_markup(reply_markup) - req = requests.post(request_url, params=payload, files=files) - return check_result(method_url, req) + return _make_request(token, method_url, params=payload, files=files, method='post') def send_location(token, chat_id, latitude, longitude, reply_to_message_id=None, reply_markup=None): - api_url = telebot.API_URL method_url = r'sendLocation' - request_url = api_url + 'bot' + token + '/' + method_url payload = {'chat_id': chat_id, 'latitude': latitude, 'longitude': longitude} if reply_to_message_id: payload['reply_to_message_id'] = reply_to_message_id if reply_markup: payload['reply_markup'] = convert_markup(reply_markup) - req = requests.get(request_url, params=payload) - return check_result(method_url, req) + return _make_request(token, method_url, params=payload) def send_chat_action(token, chat_id, action): - api_url = telebot.API_URL method_url = r'sendChatAction' - request_url = api_url + 'bot' + token + '/' + method_url payload = {'chat_id': chat_id, 'action': action} - req = requests.get(request_url, params=payload) - return check_result(method_url, req) + return _make_request(token, method_url, params=payload) def send_data(token, chat_id, data, data_type, reply_to_message_id=None, reply_markup=None): - api_url = telebot.API_URL method_url = get_method_by_type(data_type) - request_url = api_url + 'bot' + token + '/' + method_url payload = {'chat_id': chat_id} files = {data_type: data} if reply_to_message_id: payload['reply_to_message_id'] = reply_to_message_id if reply_markup: payload['reply_markup'] = convert_markup(reply_markup) - req = requests.post(request_url, params=payload, files=files) - return check_result(method_url, req) + return _make_request(token, method_url, params=payload, files=files, method='post') def get_method_by_type(data_type): @@ -134,20 +121,8 @@ def get_method_by_type(data_type): return 'sendVideo' -def check_result(func_name, result): - if result.status_code != 200: - raise ApiError(func_name + r' error.', result) - try: - result_json = result.json() - if not result_json['ok']: - raise Exception(func_name, ' failed, result=' + result_json) - except: - raise ApiError(func_name + r' error.', result) - return result_json - - def convert_markup(markup): - if isinstance(markup, types.Jsonable): + if isinstance(markup, types.JsonSerializable): return markup.to_json() return markup diff --git a/telebot/types.py b/telebot/types.py index 59727cf..a2d5ce7 100644 --- a/telebot/types.py +++ b/telebot/types.py @@ -22,7 +22,7 @@ ForceReply import json -class Jsonable: +class JsonSerializable: """ Subclasses of this class are guaranteed to be able to be converted to JSON format. All subclasses of this class must override to_json. @@ -36,10 +36,41 @@ class Jsonable: """ raise NotImplementedError -class User: +class JsonDeserializable: + """ + Subclasses of this class are guaranteed to be able to be created from a json-style dict or json formatted string. + All subclasses of this class must override de_json. + """ + @classmethod + def de_json(self, json_type): + """ + Returns an instance of this class from the given json dict or string. + + This function must be overridden by subclasses. + :return: an instance of this class created from the given json dict or string. + """ + raise NotImplementedError + + @staticmethod + def check_json(self, json_type): + """ + Checks whether json_type is a dict or a string. If it is already a dict, it is returned as-is. + If it is not, it is converted to a dict by means of json.loads(json_type) + :param json_type: + :return: + """ + if type(json_type) == dict: + return json_type + elif type(json_type) == str: + return json.loads(json_type) + else: + raise ValueError("json_type should be a json dict or string.") + + +class User(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) id = obj['id'] first_name = obj['first_name'] last_name = None @@ -57,10 +88,10 @@ class User: self.last_name = last_name -class GroupChat: +class GroupChat(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) id = obj['id'] title = obj['title'] return GroupChat(id, title) @@ -70,10 +101,10 @@ class GroupChat: self.title = title -class Message: +class Message(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) message_id = obj['message_id'] from_user = User.de_json(json.dumps(obj['from'])) chat = Message.parse_chat(obj['chat']) @@ -127,10 +158,10 @@ class Message: setattr(self, key, options[key]) -class PhotoSize: +class PhotoSize(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) file_id = obj['file_id'] width = obj['width'] height = obj['height'] @@ -146,10 +177,10 @@ class PhotoSize: self.file_id = file_id -class Audio: +class Audio(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) file_id = obj['file_id'] duration = obj['duration'] mime_type = None @@ -167,10 +198,10 @@ class Audio: self.file_size = file_size -class Document: +class Document(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) file_id = obj['file_id'] thumb = None if 'file_id' in obj['thumb']: @@ -194,10 +225,10 @@ class Document: self.file_size = file_size -class Sticker: +class Sticker(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) file_id = obj['file_id'] width = obj['width'] height = obj['height'] @@ -215,10 +246,10 @@ class Sticker: self.file_size = file_size -class Video: +class Video(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) file_id = obj['file_id'] width = obj['width'] height = obj['height'] @@ -255,10 +286,10 @@ class Contact: self.user_id = user_id -class Location: +class Location(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) longitude = obj['longitude'] latitude = obj['latitude'] return Location(longitude, latitude) @@ -268,10 +299,10 @@ class Location: self.latitude = latitude -class UserProfilePhotos: +class UserProfilePhotos(JsonDeserializable): @classmethod def de_json(cls, json_string): - obj = json.loads(json_string) + obj = cls.check_json(json_string) total_count = obj['total_count'] photos = [[PhotoSize.de_json(json.dumps(y)) for y in x] for x in obj['photos']] return UserProfilePhotos(total_count, photos) @@ -281,7 +312,7 @@ class UserProfilePhotos: self.photos = photos -class ForceReply(Jsonable): +class ForceReply(JsonSerializable): def __init__(self, selective=None): self.selective = selective @@ -292,7 +323,7 @@ class ForceReply(Jsonable): return json.dumps(json_dict) -class ReplyKeyboardHide(Jsonable): +class ReplyKeyboardHide(JsonSerializable): def __init__(self, selective=None): self.selective = selective @@ -303,7 +334,7 @@ class ReplyKeyboardHide(Jsonable): return json.dumps(json_dict) -class ReplyKeyboardMarkup(Jsonable): +class ReplyKeyboardMarkup(JsonSerializable): def __init__(self, resize_keyboard=None, one_time_keyboard=None, selective=None, row_width=3): self.resize_keyboard = resize_keyboard self.one_time_keyboard = one_time_keyboard