Compare commits

..

76 Commits
4.0.1 ... 6.0.6

Author SHA1 Message Date
050b14fb53 v6.0.6 2016-03-06 14:14:40 -08:00
c7efc33463 upgrade wakatime-cli to v4.1.13 2016-03-06 14:13:27 -08:00
d0ddbed006 v6.0.5 2016-03-06 12:49:49 -08:00
3ce8f388ab upgrade wakatime-cli to v4.1.11 2016-03-06 12:48:42 -08:00
90731146f9 add unresponsive plugin fix 2016-02-09 07:39:50 -08:00
e1ab92be6d v6.0.4 2016-01-15 12:27:52 +01:00
8b59e46c64 force as str before decoding as unicode 2016-01-15 12:24:43 +01:00
006341eb72 Merge pull request #61 from real666maverick/bug_UnicodeDecodeError
fix UnicodeDecodeError: 'ascii' codec can't decode byte 0xd0 in posiion 10: ordinal not in range(128) on plugin_loaded (ST2)
2016-01-15 12:18:45 +01:00
b54e0e13f6 fix UnicodeDecodeError: 'ascii' codec can't decode byte 0xd0 in position 10: ordinal not in range(128) on plugin_loaded (ST2) 2016-01-15 12:21:18 +03:00
835c7db864 v6.0.3 2016-01-11 11:29:52 -08:00
53e8bb04e9 upgrade wakatime-cli to v4.1.10 2016-01-11 11:27:56 -08:00
4aa06e3829 v6.0.2 2016-01-06 13:59:56 -08:00
297f65733f upgrade wakatime-cli core to v4.1.9 2016-01-06 13:57:03 -08:00
5ba5e6d21b update plugin description 2016-01-03 17:47:56 -08:00
32eadda81f v6.0.1 2016-01-01 23:09:08 -08:00
c537044801 log output from wakatime-cli when in debug mode 2016-01-01 23:06:03 -08:00
a97792c23c make sure python --version outputs a version number 2016-01-01 22:20:02 -08:00
4223f3575f v6.0.0 2015-12-01 01:06:00 -08:00
284cdf3ce4 update log messages for embedded python 2015-12-01 01:04:41 -08:00
27afc41bf4 remove python zip file after extracting 2015-12-01 01:00:48 -08:00
1fdda0d64a use embeddable python on Windows instead of installing python 2015-12-01 00:51:09 -08:00
c90a4863e9 v5.0.1 2015-10-06 04:04:15 -07:00
94343e5b07 look for python in system PATH again 2015-10-06 03:59:04 -07:00
03acea6e25 v5.0.0 2015-10-02 12:11:04 -07:00
77594700bd switch registry warnings to debug log level 2015-10-02 12:09:37 -07:00
6681409e98 v4.0.20 2015-10-01 15:20:34 -07:00
8f7837269a correctly find python binary in non-Windows environments 2015-10-01 15:19:58 -07:00
a523b3aa4d v4.0.19 2015-10-01 15:15:15 -07:00
6985ce32bb handle case where Sublime Text does not have winreg builtin module 2015-10-01 15:14:36 -07:00
4be40c7720 v4.0.18 2015-10-01 15:08:35 -07:00
eeb7fd8219 find python binary location from windows registry 2015-10-01 15:07:55 -07:00
11fbd2d2a6 v4.0.17 2015-10-01 12:42:04 -07:00
3cecd0de5d improve C# and php dependency detection 2015-10-01 12:41:50 -07:00
c50100e675 download and install python in background thread to fix #53 2015-10-01 12:39:28 -07:00
c1da94bc18 better looking obfuscated api key 2015-10-01 11:24:13 -07:00
7f9d6ede9d v4.0.16 2015-09-29 03:13:04 -07:00
192a5c7aa7 upgrade wakatime cli to v4.1.8 2015-09-29 03:11:25 -07:00
16bbe21be9 update project description 2015-09-26 11:51:19 -07:00
5ebaf12a99 v4.0.15 2015-08-28 15:06:49 -07:00
1834e8978a upgrade wakatime cli to v4.1.3 2015-08-28 15:05:55 -07:00
22c8ed74bd v4.0.14 2015-08-25 11:20:45 -07:00
12bbb4e561 upgrade wakatime cli to v4.1.2 2015-08-25 11:20:12 -07:00
c71cb21cc1 remove extra line 2015-08-25 00:52:40 -07:00
eb11b991f0 v4.0.13 2015-08-25 00:43:56 -07:00
7ea51d09ba upgrade wakatime cli to v4.1.1 2015-08-25 00:42:37 -07:00
b07b59e0c8 v4.0.12 2015-07-31 15:34:36 -07:00
9d715e95b7 correctly use urllib in python3 2015-07-31 15:34:24 -07:00
3edaed53aa v4.0.11 2015-07-31 15:20:57 -07:00
865b0bcee9 install python on Windows if not already installed 2015-07-31 15:19:56 -07:00
d440fe912c v4.0.10 2015-07-31 13:27:58 -07:00
627455167f downgrade requests library to v2.6.0 2015-07-31 13:27:04 -07:00
aba89d3948 v4.0.9 2015-07-29 00:04:39 -07:00
18d87118e1 catch exceptions from get_filetype_from_buffer 2015-07-29 00:03:18 -07:00
fd91b9e032 link to wakatime/wakatime#troubleshooting 2015-07-15 13:46:26 -07:00
16b15773bf troubleshooting section in readme 2015-07-15 13:44:07 -07:00
f0b518862a upgrade wakatime cli to v4.1.0 2015-06-29 19:47:04 -07:00
7ee7de70d5 v4.0.8 2015-06-23 18:17:25 -07:00
fb479f8e84 fix offline logging with wakatime cli v4.0.16 2015-06-23 18:15:38 -07:00
7d37193f65 v4.0.7 2015-06-21 10:45:51 -07:00
6bd62b95db allow customizing status bar message in sublime-settings file 2015-06-21 10:42:31 -07:00
abf4a94a59 upgrade wakatime cli to v4.0.15 2015-06-21 10:35:14 -07:00
9337e3173b v4.0.6 2015-05-16 14:38:58 -07:00
57fa4d4d84 upgrade wakatime cli to v4.0.13 2015-05-16 14:38:19 -07:00
9b5c59e677 v4.0.5 2015-05-15 15:34:17 -07:00
71ce25a326 upgrade wakatime cli to v4.0.12 2015-05-15 15:33:03 -07:00
f2f14207f5 use new --alternate-project argument so auto detected project will take priority 2015-05-15 15:32:03 -07:00
ac2ec0e73c v4.0.4 2015-05-12 15:04:39 -07:00
040a76b93c upgrade wakatime cli to v4.0.11 2015-05-12 15:03:23 -07:00
dab0621b97 v4.0.3 2015-05-06 16:35:01 -07:00
675f9ecd69 send cursorpos to wakatime cli 2015-05-06 16:34:15 -07:00
a6f92b9c74 upgrade wakatime cli to v4.0.10 2015-05-06 16:33:32 -07:00
bfcc242d7e upgrade wakatime cli to v4.0.9 2015-05-06 15:45:34 -07:00
762027644f send current cursor line number to wakatime cli 2015-05-06 15:43:41 -07:00
3c4ceb95fa separate active view logic into own function 2015-05-06 14:06:06 -07:00
d6d8bceca0 v4.0.2 2015-05-06 14:01:35 -07:00
acaad2dc83 only send heartbeats for the currently active buffer, for cases where another process modifies files which are open in sublime text 2015-05-06 14:00:33 -07:00
80 changed files with 4499 additions and 2012 deletions

View File

@ -3,6 +3,216 @@ History
-------
6.0.6 (2016-03-06)
++++++++++++++++++
- upgrade wakatime-cli to v4.1.13
- encode TimeZone as utf-8 before adding to headers
- encode X-Machine-Name as utf-8 before adding to headers
6.0.5 (2016-03-06)
++++++++++++++++++
- upgrade wakatime-cli to v4.1.11
- encode machine hostname as Unicode when adding to X-Machine-Name header
6.0.4 (2016-01-15)
++++++++++++++++++
- fix UnicodeDecodeError on ST2 with non-English locale
6.0.3 (2016-01-11)
++++++++++++++++++
- upgrade wakatime-cli core to v4.1.10
- accept 201 or 202 response codes as success from api
- upgrade requests package to v2.9.1
6.0.2 (2016-01-06)
++++++++++++++++++
- upgrade wakatime-cli core to v4.1.9
- improve C# dependency detection
- correctly log exception tracebacks
- log all unknown exceptions to wakatime.log file
- disable urllib3 SSL warning from every request
- detect dependencies from golang files
- use api.wakatime.com for sending heartbeats
6.0.1 (2016-01-01)
++++++++++++++++++
- use embedded python if system python is broken, or doesn't output a version number
- log output from wakatime-cli in ST console when in debug mode
6.0.0 (2015-12-01)
++++++++++++++++++
- use embeddable Python instead of installing on Windows
5.0.1 (2015-10-06)
++++++++++++++++++
- look for python in system PATH again
5.0.0 (2015-10-02)
++++++++++++++++++
- improve logging with levels and log function
- switch registry warnings to debug log level
4.0.20 (2015-10-01)
++++++++++++++++++
- correctly find python binary in non-Windows environments
4.0.19 (2015-10-01)
++++++++++++++++++
- handle case where ST builtin python does not have _winreg or winreg module
4.0.18 (2015-10-01)
++++++++++++++++++
- find python location from windows registry
4.0.17 (2015-10-01)
++++++++++++++++++
- download python in non blocking background thread for Windows machines
4.0.16 (2015-09-29)
++++++++++++++++++
- upgrade wakatime cli to v4.1.8
- fix bug in guess_language function
- improve dependency detection
- default request timeout of 30 seconds
- new --timeout command line argument to change request timeout in seconds
- allow passing command line arguments using sys.argv
- fix entry point for pypi distribution
- new --entity and --entitytype command line arguments
4.0.15 (2015-08-28)
++++++++++++++++++
- upgrade wakatime cli to v4.1.3
- fix local session caching
4.0.14 (2015-08-25)
++++++++++++++++++
- upgrade wakatime cli to v4.1.2
- fix bug in offline caching which prevented heartbeats from being cleaned up
4.0.13 (2015-08-25)
++++++++++++++++++
- upgrade wakatime cli to v4.1.1
- send hostname in X-Machine-Name header
- catch exceptions from pygments.modeline.get_filetype_from_buffer
- upgrade requests package to v2.7.0
- handle non-ASCII characters in import path on Windows, won't fix for Python2
- upgrade argparse to v1.3.0
- move language translations to api server
- move extension rules to api server
- detect correct header file language based on presence of .cpp or .c files named the same as the .h file
4.0.12 (2015-07-31)
++++++++++++++++++
- correctly use urllib in Python3
4.0.11 (2015-07-31)
++++++++++++++++++
- install python if missing on Windows OS
4.0.10 (2015-07-31)
++++++++++++++++++
- downgrade requests library to v2.6.0
4.0.9 (2015-07-29)
++++++++++++++++++
- catch exceptions from pygments.modeline.get_filetype_from_buffer
4.0.8 (2015-06-23)
++++++++++++++++++
- fix offline logging
- limit language detection to known file extensions, unless file contents has a vim modeline
- upgrade wakatime cli to v4.0.16
4.0.7 (2015-06-21)
++++++++++++++++++
- allow customizing status bar message in sublime-settings file
- guess language using multiple methods, then use most accurate guess
- use entity and type for new heartbeats api resource schema
- correctly log message from py.warnings module
- upgrade wakatime cli to v4.0.15
4.0.6 (2015-05-16)
++++++++++++++++++
- fix bug with auto detecting project name
- upgrade wakatime cli to v4.0.13
4.0.5 (2015-05-15)
++++++++++++++++++
- correctly display caller and lineno in log file when debug is true
- project passed with --project argument will always be used
- new --alternate-project argument
- upgrade wakatime cli to v4.0.12
4.0.4 (2015-05-12)
++++++++++++++++++
- reuse SSL connection over multiple processes for improved performance
- upgrade wakatime cli to v4.0.11
4.0.3 (2015-05-06)
++++++++++++++++++
- send cursorpos to wakatime cli
- upgrade wakatime cli to v4.0.10
4.0.2 (2015-05-06)
++++++++++++++++++
- only send heartbeats for the currently active buffer
4.0.1 (2015-05-06)
++++++++++++++++++

View File

@ -1,13 +1,12 @@
sublime-wakatime
================
Fully automatic time tracking for Sublime Text 2 & 3.
Metrics, insights, and time tracking automatically generated from your programming activity.
Installation
------------
Heads Up! For Sublime Text 2 on Windows & Linux, WakaTime depends on [Python](http://www.python.org/getit/) being installed to work correctly.
1. Install [Package Control](https://packagecontrol.io/installation).
2. Using [Package Control](https://packagecontrol.io/docs/usage):
@ -24,8 +23,34 @@ Heads Up! For Sublime Text 2 on Windows & Linux, WakaTime depends on [Python](ht
5. Visit https://wakatime.com/dashboard to see your logged time.
Screen Shots
------------
![Project Overview](https://wakatime.com/static/img/ScreenShots/ScreenShot-2014-10-29.png)
Unresponsive Plugin Warning
---------------------------
In Sublime Text 2, if you get a warning message:
A plugin (WakaTime) may be making Sublime Text unresponsive by taking too long (0.017332s) in its on_modified callback.
To fix this, go to `Preferences > Settings - User` then add the following setting:
`"detect_slow_plugins": false`
Troubleshooting
---------------
First, turn on debug mode in your `WakaTime.sublime-settings` file.
![sublime user settings](https://wakatime.com/static/img/ScreenShots/sublime-wakatime-settings-menu.png)
Add the line: `"debug": true`
Then, open your Sublime Console with `View -> Show Console` to see the plugin executing the wakatime cli process when sending a heartbeat. Also, tail your `$HOME/.wakatime.log` file to debug wakatime cli problems.
For more general troubleshooting information, see [wakatime/wakatime#troubleshooting](https://github.com/wakatime/wakatime#troubleshooting).

View File

@ -7,21 +7,74 @@ Website: https://wakatime.com/
==========================================================="""
__version__ = '4.0.1'
__version__ = '6.0.6'
import sublime
import sublime_plugin
import glob
import os
import platform
import re
import sys
import time
import threading
import urllib
import webbrowser
from datetime import datetime
from subprocess import Popen
from zipfile import ZipFile
from subprocess import Popen, STDOUT, PIPE
try:
import _winreg as winreg # py2
except ImportError:
try:
import winreg # py3
except ImportError:
winreg = None
is_py2 = (sys.version_info[0] == 2)
is_py3 = (sys.version_info[0] == 3)
if is_py2:
def u(text):
if text is None:
return None
try:
text = str(text)
return text.decode('utf-8')
except:
try:
return text.decode(sys.getdefaultencoding())
except:
try:
return unicode(text)
except:
return text
elif is_py3:
def u(text):
if text is None:
return None
if isinstance(text, bytes):
try:
return text.decode('utf-8')
except:
try:
return text.decode(sys.getdefaultencoding())
except:
pass
try:
return str(text)
except:
return text
else:
raise Exception('Unsupported Python version: {0}.{1}.{2}'.format(
sys.version_info[0],
sys.version_info[1],
sys.version_info[2],
))
# globals
@ -40,6 +93,13 @@ LOCK = threading.RLock()
PYTHON_LOCATION = None
# Log Levels
DEBUG = 'DEBUG'
INFO = 'INFO'
WARNING = 'WARNING'
ERROR = 'ERROR'
# add wakatime package to path
sys.path.insert(0, os.path.join(PLUGIN_DIR, 'packages'))
try:
@ -48,6 +108,20 @@ except ImportError:
pass
def log(lvl, message, *args, **kwargs):
try:
if lvl == DEBUG and not SETTINGS.get('debug'):
return
msg = message
if len(args) > 0:
msg = message.format(*args)
elif len(kwargs) > 0:
msg = message.format(**kwargs)
print('[WakaTime] [{lvl}] {msg}'.format(lvl=lvl, msg=msg))
except RuntimeError:
sublime.set_timeout(lambda: log(lvl, message, *args, **kwargs), 0)
def createConfigFile():
"""Creates the .wakatime.cfg INI file in $HOME directory, if it does
not already exist.
@ -92,35 +166,130 @@ def prompt_api_key():
window.show_input_panel('[WakaTime] Enter your wakatime.com api key:', default_key, got_key, None, None)
return True
else:
print('[WakaTime] Error: Could not prompt for api key because no window found.')
log(ERROR, 'Could not prompt for api key because no window found.')
return False
def python_binary():
global PYTHON_LOCATION
if PYTHON_LOCATION is not None:
return PYTHON_LOCATION
# look for python in PATH and common install locations
paths = [
"pythonw",
"python",
"/usr/local/bin/python",
"/usr/bin/python",
os.path.join(os.path.expanduser('~'), '.wakatime', 'python'),
None,
'/',
'/usr/local/bin/',
'/usr/bin/',
]
for path in paths:
try:
Popen([path, '--version'])
PYTHON_LOCATION = path
path = find_python_in_folder(path)
if path is not None:
set_python_binary_location(path)
return path
except:
pass
for path in glob.iglob('/python*'):
path = os.path.realpath(os.path.join(path, 'pythonw'))
try:
Popen([path, '--version'])
PYTHON_LOCATION = path
# look for python in windows registry
path = find_python_from_registry(r'SOFTWARE\Python\PythonCore')
if path is not None:
set_python_binary_location(path)
return path
path = find_python_from_registry(r'SOFTWARE\Wow6432Node\Python\PythonCore')
if path is not None:
set_python_binary_location(path)
return path
return None
def set_python_binary_location(path):
global PYTHON_LOCATION
PYTHON_LOCATION = path
log(DEBUG, 'Found Python at: {0}'.format(path))
def find_python_from_registry(location, reg=None):
if platform.system() != 'Windows' or winreg is None:
return None
if reg is None:
path = find_python_from_registry(location, reg=winreg.HKEY_CURRENT_USER)
if path is None:
path = find_python_from_registry(location, reg=winreg.HKEY_LOCAL_MACHINE)
return path
val = None
sub_key = 'InstallPath'
compiled = re.compile(r'^\d+\.\d+$')
try:
with winreg.OpenKey(reg, location) as handle:
versions = []
try:
for index in range(1024):
version = winreg.EnumKey(handle, index)
try:
if compiled.search(version):
versions.append(version)
except re.error:
pass
except EnvironmentError:
pass
versions.sort(reverse=True)
for version in versions:
try:
path = winreg.QueryValue(handle, version + '\\' + sub_key)
if path is not None:
path = find_python_in_folder(path)
if path is not None:
log(DEBUG, 'Found python from {reg}\\{key}\\{version}\\{sub_key}.'.format(
reg=reg,
key=location,
version=version,
sub_key=sub_key,
))
return path
except WindowsError:
log(DEBUG, 'Could not read registry value "{reg}\\{key}\\{version}\\{sub_key}".'.format(
reg=reg,
key=location,
version=version,
sub_key=sub_key,
))
except WindowsError:
if SETTINGS.get('debug'):
log(DEBUG, 'Could not read registry value "{reg}\\{key}".'.format(
reg=reg,
key=location,
))
return val
def find_python_in_folder(folder, headless=True):
pattern = re.compile(r'\d+\.\d+')
path = 'python'
if folder is not None:
path = os.path.realpath(os.path.join(folder, 'python'))
if headless:
path = u(path) + u('w')
log(DEBUG, u('Looking for Python at: {0}').format(path))
try:
process = Popen([path, '--version'], stdout=PIPE, stderr=STDOUT)
output, err = process.communicate()
output = u(output).strip()
retcode = process.poll()
log(DEBUG, u('Python Version Output: {0}').format(output))
if not retcode and pattern.search(output):
return path
except:
pass
except:
log(DEBUG, u('Python Version Output: {0}').format(u(sys.exc_info()[1])))
if headless:
path = find_python_in_folder(folder, headless=False)
if path is not None:
return path
return None
@ -132,7 +301,7 @@ def obfuscate_apikey(command_list):
apikey_index = num + 1
break
if apikey_index is not None and apikey_index < len(cmd):
cmd[apikey_index] = '********-****-****-****-********' + cmd[apikey_index][-4:]
cmd[apikey_index] = 'XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXX' + cmd[apikey_index][-4:]
return cmd
@ -173,6 +342,16 @@ def find_project_from_folders(folders, current_file):
return os.path.basename(folder) if folder else None
def is_view_active(view):
if view:
active_window = sublime.active_window()
if active_window:
active_view = active_window.active_view()
if active_view:
return active_view.buffer_id() == view.buffer_id()
return False
def handle_heartbeat(view, is_write=False):
window = view.window()
if window is not None:
@ -184,6 +363,8 @@ def handle_heartbeat(view, is_write=False):
class SendHeartbeatThread(threading.Thread):
"""Non-blocking thread for sending heartbeats to api.
"""
def __init__(self, target_file, view, is_write=False, project=None, folders=None, force=False):
threading.Thread.__init__(self)
@ -197,6 +378,7 @@ class SendHeartbeatThread(threading.Thread):
self.api_key = SETTINGS.get('api_key', '')
self.ignore = SETTINGS.get('ignore', [])
self.last_heartbeat = LAST_HEARTBEAT.copy()
self.cursorpos = view.sel()[0].begin() if view.sel() else None
self.view = view
def run(self):
@ -208,7 +390,7 @@ class SendHeartbeatThread(threading.Thread):
def send_heartbeat(self):
if not self.api_key:
print('[WakaTime] Error: missing api key.')
log(ERROR, 'missing api key.')
return
ua = 'sublime/%d sublime-wakatime/%s' % (ST_VERSION, __version__)
cmd = [
@ -221,27 +403,39 @@ class SendHeartbeatThread(threading.Thread):
if self.is_write:
cmd.append('--write')
if self.project and self.project.get('name'):
cmd.extend(['--project', self.project.get('name')])
cmd.extend(['--alternate-project', self.project.get('name')])
elif self.folders:
project_name = find_project_from_folders(self.folders, self.target_file)
if project_name:
cmd.extend(['--project', project_name])
cmd.extend(['--alternate-project', project_name])
if self.cursorpos is not None:
cmd.extend(['--cursorpos', '{0}'.format(self.cursorpos)])
for pattern in self.ignore:
cmd.extend(['--ignore', pattern])
if self.debug:
cmd.append('--verbose')
if python_binary():
cmd.insert(0, python_binary())
if self.debug:
print('[WakaTime] %s' % ' '.join(obfuscate_apikey(cmd)))
if platform.system() == 'Windows':
Popen(cmd, shell=False)
else:
with open(os.path.join(os.path.expanduser('~'), '.wakatime.log'), 'a') as stderr:
Popen(cmd, stderr=stderr)
self.sent()
log(DEBUG, ' '.join(obfuscate_apikey(cmd)))
try:
if not self.debug:
Popen(cmd)
self.sent()
else:
process = Popen(cmd, stdout=PIPE, stderr=STDOUT)
output, err = process.communicate()
output = u(output)
retcode = process.poll()
if (not retcode or retcode == 102) and not output:
self.sent()
if retcode:
log(DEBUG if retcode == 102 else ERROR, 'wakatime-core exited with status: {0}'.format(retcode))
if output:
log(ERROR, u('wakatime-core output: {0}').format(output))
except:
log(ERROR, u(sys.exc_info()[1]))
else:
print('[WakaTime] Error: Unable to find python binary.')
log(ERROR, 'Unable to find python binary.')
def sent(self):
sublime.set_timeout(self.set_status_bar, 0)
@ -249,7 +443,7 @@ class SendHeartbeatThread(threading.Thread):
def set_status_bar(self):
if SETTINGS.get('status_bar_message'):
self.view.set_status('wakatime', 'WakaTime active {0}'.format(datetime.now().strftime('%I:%M %p')))
self.view.set_status('wakatime', datetime.now().strftime(SETTINGS.get('status_bar_message_fmt')))
def set_last_heartbeat(self):
global LAST_HEARTBEAT
@ -260,15 +454,57 @@ class SendHeartbeatThread(threading.Thread):
}
class DownloadPython(threading.Thread):
"""Non-blocking thread for extracting embeddable Python on Windows machines.
"""
def run(self):
log(INFO, 'Downloading embeddable Python...')
ver = '3.5.0'
arch = 'amd64' if platform.architecture()[0] == '64bit' else 'win32'
url = 'https://www.python.org/ftp/python/{ver}/python-{ver}-embed-{arch}.zip'.format(
ver=ver,
arch=arch,
)
if not os.path.exists(os.path.join(os.path.expanduser('~'), '.wakatime')):
os.makedirs(os.path.join(os.path.expanduser('~'), '.wakatime'))
zip_file = os.path.join(os.path.expanduser('~'), '.wakatime', 'python.zip')
try:
urllib.urlretrieve(url, zip_file)
except AttributeError:
urllib.request.urlretrieve(url, zip_file)
log(INFO, 'Extracting Python...')
with ZipFile(zip_file) as zf:
path = os.path.join(os.path.expanduser('~'), '.wakatime', 'python')
zf.extractall(path)
try:
os.remove(zip_file)
except:
pass
log(INFO, 'Finished extracting Python.')
def plugin_loaded():
global SETTINGS
print('[WakaTime] Initializing WakaTime plugin v%s' % __version__)
if not python_binary():
sublime.error_message("Unable to find Python binary!\nWakaTime needs Python to work correctly.\n\nGo to https://www.python.org/downloads")
return
log(INFO, 'Initializing WakaTime plugin v%s' % __version__)
SETTINGS = sublime.load_settings(SETTINGS_FILE)
if not python_binary():
log(WARNING, 'Python binary not found.')
if platform.system() == 'Windows':
thread = DownloadPython()
thread.start()
else:
sublime.error_message("Unable to find Python binary!\nWakaTime needs Python to work correctly.\n\nGo to https://www.python.org/downloads")
return
after_loaded()
@ -288,10 +524,12 @@ class WakatimeListener(sublime_plugin.EventListener):
handle_heartbeat(view, is_write=True)
def on_selection_modified(self, view):
handle_heartbeat(view)
if is_view_active(view):
handle_heartbeat(view)
def on_modified(self, view):
handle_heartbeat(view)
if is_view_active(view):
handle_heartbeat(view)
class WakatimeDashboardCommand(sublime_plugin.ApplicationCommand):

View File

@ -16,5 +16,8 @@
// Status bar message. Set to false to hide status bar message.
// Defaults to true.
"status_bar_message": true
"status_bar_message": true,
// Status bar message format.
"status_bar_message_fmt": "WakaTime active %I:%M %p"
}

View File

@ -1,9 +1,9 @@
__title__ = 'wakatime'
__description__ = 'Common interface to the WakaTime api.'
__url__ = 'https://github.com/wakatime/wakatime'
__version_info__ = ('4', '0', '8')
__version_info__ = ('4', '1', '13')
__version__ = '.'.join(__version_info__)
__author__ = 'Alan Hamlett'
__author_email__ = 'alan@wakatime.com'
__license__ = 'BSD'
__copyright__ = 'Copyright 2014 Alan Hamlett'
__copyright__ = 'Copyright 2016 Alan Hamlett'

View File

@ -14,4 +14,4 @@
__all__ = ['main']
from .base import main
from .main import execute

View File

@ -11,8 +11,25 @@
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import wakatime
# get path to local wakatime package
package_folder = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# add local wakatime package to sys.path
sys.path.insert(0, package_folder)
# import local wakatime package
try:
import wakatime
except (TypeError, ImportError):
# on Windows, non-ASCII characters in import path can be fixed using
# the script path from sys.argv[0].
# More info at https://github.com/wakatime/wakatime/issues/32
package_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.insert(0, package_folder)
import wakatime
if __name__ == '__main__':
sys.exit(wakatime.main(sys.argv))
sys.exit(wakatime.execute(sys.argv[1:]))

View File

@ -17,32 +17,49 @@ is_py2 = (sys.version_info[0] == 2)
is_py3 = (sys.version_info[0] == 3)
if is_py2:
if is_py2: # pragma: nocover
def u(text):
if text is None:
return None
try:
return text.decode('utf-8')
except:
try:
return unicode(text)
return text.decode(sys.getdefaultencoding())
except:
return text
try:
return unicode(text)
except:
return text
open = codecs.open
basestring = basestring
elif is_py3:
elif is_py3: # pragma: nocover
def u(text):
if text is None:
return None
if isinstance(text, bytes):
return text.decode('utf-8')
return str(text)
try:
return text.decode('utf-8')
except:
try:
return text.decode(sys.getdefaultencoding())
except:
pass
try:
return str(text)
except:
return text
open = open
basestring = (str, bytes)
try:
from importlib import import_module
except ImportError:
except ImportError: # pragma: nocover
def _resolve_name(name, package, level):
"""Return the absolute name of the module to be imported."""
if not hasattr(package, 'rindex'):

View File

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
"""
wakatime.constants
~~~~~~~~~~~~~~~~~~
Constant variable definitions.
:copyright: (c) 2016 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
SUCCESS = 0
API_ERROR = 102
CONFIG_FILE_PARSE_ERROR = 103
AUTH_ERROR = 104
UNKNOWN_ERROR = 105

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages
~~~~~~~~~~~~~~~~~~
wakatime.dependencies
~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from a source code file.
@ -10,9 +10,12 @@
"""
import logging
import re
import sys
import traceback
from ..compat import u, open, import_module
from ..exceptions import NotYetImplemented
log = logging.getLogger('WakaTime')
@ -23,26 +26,28 @@ class TokenParser(object):
language, inherit from this class and implement the :meth:`parse` method
to return a list of dependency strings.
"""
source_file = None
lexer = None
dependencies = []
tokens = []
exclude = []
def __init__(self, source_file, lexer=None):
self._tokens = None
self.dependencies = []
self.source_file = source_file
self.lexer = lexer
self.exclude = [re.compile(x, re.IGNORECASE) for x in self.exclude]
@property
def tokens(self):
if self._tokens is None:
self._tokens = self._extract_tokens()
return self._tokens
def parse(self, tokens=[]):
""" Should return a list of dependencies.
"""
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
raise Exception('Not yet implemented.')
raise NotYetImplemented()
def append(self, dep, truncate=False, separator=None, truncate_to=None,
strip_whitespace=True):
if dep == 'as':
print('***************** as')
self._save_dependency(
dep,
truncate=truncate,
@ -51,10 +56,21 @@ class TokenParser(object):
strip_whitespace=strip_whitespace,
)
def partial(self, token):
return u(token).split('.')[-1]
def _extract_tokens(self):
if self.lexer:
with open(self.source_file, 'r', encoding='utf-8') as fh:
return self.lexer.get_tokens_unprocessed(fh.read(512000))
try:
with open(self.source_file, 'r', encoding='utf-8') as fh:
return self.lexer.get_tokens_unprocessed(fh.read(512000))
except:
pass
try:
with open(self.source_file, 'r', encoding=sys.getfilesystemencoding()) as fh:
return self.lexer.get_tokens_unprocessed(fh.read(512000))
except:
pass
return []
def _save_dependency(self, dep, truncate=False, separator=None,
@ -64,13 +80,21 @@ class TokenParser(object):
separator = u('.')
separator = u(separator)
dep = dep.split(separator)
if truncate_to is None or truncate_to < 0 or truncate_to > len(dep) - 1:
truncate_to = len(dep) - 1
dep = dep[0] if len(dep) == 1 else separator.join(dep[0:truncate_to])
if truncate_to is None or truncate_to < 1:
truncate_to = 1
if truncate_to > len(dep):
truncate_to = len(dep)
dep = dep[0] if len(dep) == 1 else separator.join(dep[:truncate_to])
if strip_whitespace:
dep = dep.strip()
if dep:
self.dependencies.append(dep)
if dep and (not separator or not dep.startswith(separator)):
should_exclude = False
for compiled in self.exclude:
if compiled.search(dep):
should_exclude = True
break
if not should_exclude:
self.dependencies.append(dep)
class DependencyParser(object):
@ -83,7 +107,7 @@ class DependencyParser(object):
self.lexer = lexer
if self.lexer:
module_name = self.lexer.__module__.split('.')[-1]
module_name = self.lexer.__module__.rsplit('.', 1)[-1]
class_name = self.lexer.__class__.__name__.replace('Lexer', 'Parser', 1)
else:
module_name = 'unknown'

View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages.c_cpp
~~~~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from C++ code.
:copyright: (c) 2014 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from . import TokenParser
class CppParser(TokenParser):
exclude = [
r'^stdio\.h$',
r'^stdlib\.h$',
r'^string\.h$',
r'^time\.h$',
]
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if self.partial(token) == 'Preproc':
self._process_preproc(token, content)
else:
self._process_other(token, content)
def _process_preproc(self, token, content):
if content.strip().startswith('include ') or content.strip().startswith("include\t"):
content = content.replace('include', '', 1).strip().strip('"').strip('<').strip('>').strip()
self.append(content)
def _process_other(self, token, content):
pass
class CParser(TokenParser):
exclude = [
r'^stdio\.h$',
r'^stdlib\.h$',
r'^string\.h$',
r'^time\.h$',
]
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if self.partial(token) == 'Preproc':
self._process_preproc(token, content)
else:
self._process_other(token, content)
def _process_preproc(self, token, content):
if content.strip().startswith('include ') or content.strip().startswith("include\t"):
content = content.replace('include', '', 1).strip().strip('"').strip('<').strip('>').strip()
self.append(content)
def _process_other(self, token, content):
pass

View File

@ -26,10 +26,8 @@ class JsonParser(TokenParser):
state = None
level = 0
def parse(self, tokens=[]):
def parse(self):
self._process_file_name(os.path.basename(self.source_file))
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies

View File

@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages.dotnet
~~~~~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from .NET code.
:copyright: (c) 2014 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from . import TokenParser
from ..compat import u
class CSharpParser(TokenParser):
exclude = [
r'^system$',
r'^microsoft$',
]
state = None
buffer = u('')
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if self.partial(token) == 'Keyword':
self._process_keyword(token, content)
if self.partial(token) == 'Namespace' or self.partial(token) == 'Name':
self._process_namespace(token, content)
elif self.partial(token) == 'Punctuation':
self._process_punctuation(token, content)
else:
self._process_other(token, content)
def _process_keyword(self, token, content):
if content == 'using':
self.state = 'import'
self.buffer = u('')
def _process_namespace(self, token, content):
if self.state == 'import':
if u(content) != u('import') and u(content) != u('package') and u(content) != u('namespace') and u(content) != u('static'):
if u(content) == u(';'): # pragma: nocover
self._process_punctuation(token, content)
else:
self.buffer += u(content)
def _process_punctuation(self, token, content):
if self.state == 'import':
if u(content) == u(';'):
self.append(self.buffer, truncate=True)
self.buffer = u('')
self.state = None
elif u(content) == u('='):
self.buffer = u('')
else:
self.buffer += u(content)
def _process_other(self, token, content):
pass

View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages.go
~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from Go code.
:copyright: (c) 2016 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from . import TokenParser
class GoParser(TokenParser):
state = None
parens = 0
aliases = 0
exclude = [
r'^"fmt"$',
]
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if self.partial(token) == 'Namespace':
self._process_namespace(token, content)
elif self.partial(token) == 'Punctuation':
self._process_punctuation(token, content)
elif self.partial(token) == 'String':
self._process_string(token, content)
elif self.partial(token) == 'Text':
self._process_text(token, content)
elif self.partial(token) == 'Other':
self._process_other(token, content)
else:
self._process_misc(token, content)
def _process_namespace(self, token, content):
self.state = content
self.parens = 0
self.aliases = 0
def _process_string(self, token, content):
if self.state == 'import':
self.append(content, truncate=False)
def _process_punctuation(self, token, content):
if content == '(':
self.parens += 1
elif content == ')':
self.parens -= 1
elif content == '.':
self.aliases += 1
else:
self.state = None
def _process_text(self, token, content):
if self.state == 'import':
if content == "\n" and self.parens <= 0:
self.state = None
self.parens = 0
self.aliases = 0
else:
self.state = None
def _process_other(self, token, content):
if self.state == 'import':
self.aliases += 1
else:
self.state = None
def _process_misc(self, token, content):
self.state = None

View File

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages.java
~~~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from Java code.
:copyright: (c) 2014 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from . import TokenParser
from ..compat import u
class JavaParser(TokenParser):
exclude = [
r'^java\.',
r'^javax\.',
r'^import$',
r'^package$',
r'^namespace$',
r'^static$',
]
state = None
buffer = u('')
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if self.partial(token) == 'Namespace':
self._process_namespace(token, content)
if self.partial(token) == 'Name':
self._process_name(token, content)
elif self.partial(token) == 'Attribute':
self._process_attribute(token, content)
elif self.partial(token) == 'Operator':
self._process_operator(token, content)
else:
self._process_other(token, content)
def _process_namespace(self, token, content):
if u(content) == u('import'):
self.state = 'import'
elif self.state == 'import':
keywords = [
u('package'),
u('namespace'),
u('static'),
]
if u(content) in keywords:
return
self.buffer = u('{0}{1}').format(self.buffer, u(content))
elif self.state == 'import-finished':
content = content.split(u('.'))
if len(content) == 1:
self.append(content[0])
elif len(content) > 1:
if len(content[0]) == 3:
content = content[1:]
if content[-1] == u('*'):
content = content[:len(content) - 1]
if len(content) == 1:
self.append(content[0])
elif len(content) > 1:
self.append(u('.').join(content[:2]))
self.state = None
def _process_name(self, token, content):
if self.state == 'import':
self.buffer = u('{0}{1}').format(self.buffer, u(content))
def _process_attribute(self, token, content):
if self.state == 'import':
self.buffer = u('{0}{1}').format(self.buffer, u(content))
def _process_operator(self, token, content):
if u(content) == u(';'):
self.state = 'import-finished'
self._process_namespace(token, self.buffer)
self.state = None
self.buffer = u('')
elif self.state == 'import':
self.buffer = u('{0}{1}').format(self.buffer, u(content))
def _process_other(self, token, content):
pass

View File

@ -17,15 +17,13 @@ class PhpParser(TokenParser):
state = None
parens = 0
def parse(self, tokens=[]):
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if u(token).split('.')[-1] == 'Keyword':
if self.partial(token) == 'Keyword':
self._process_keyword(token, content)
elif u(token) == 'Token.Literal.String.Single' or u(token) == 'Token.Literal.String.Double':
self._process_literal_string(token, content)
@ -33,9 +31,9 @@ class PhpParser(TokenParser):
self._process_name(token, content)
elif u(token) == 'Token.Name.Function':
self._process_function(token, content)
elif u(token).split('.')[-1] == 'Punctuation':
elif self.partial(token) == 'Punctuation':
self._process_punctuation(token, content)
elif u(token).split('.')[-1] == 'Text':
elif self.partial(token) == 'Text':
self._process_text(token, content)
else:
self._process_other(token, content)
@ -63,10 +61,10 @@ class PhpParser(TokenParser):
def _process_literal_string(self, token, content):
if self.state == 'include':
if content != '"':
if content != '"' and content != "'":
content = content.strip()
if u(token) == 'Token.Literal.String.Double':
content = u('"{0}"').format(content)
content = u("'{0}'").format(content)
self.append(content)
self.state = None

View File

@ -10,33 +10,30 @@
"""
from . import TokenParser
from ..compat import u
class PythonParser(TokenParser):
state = None
parens = 0
nonpackage = False
exclude = [
r'^os$',
r'^sys\.',
]
def parse(self, tokens=[]):
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if u(token).split('.')[-1] == 'Namespace':
if self.partial(token) == 'Namespace':
self._process_namespace(token, content)
elif u(token).split('.')[-1] == 'Name':
self._process_name(token, content)
elif u(token).split('.')[-1] == 'Word':
self._process_word(token, content)
elif u(token).split('.')[-1] == 'Operator':
elif self.partial(token) == 'Operator':
self._process_operator(token, content)
elif u(token).split('.')[-1] == 'Punctuation':
elif self.partial(token) == 'Punctuation':
self._process_punctuation(token, content)
elif u(token).split('.')[-1] == 'Text':
elif self.partial(token) == 'Text':
self._process_text(token, content)
else:
self._process_other(token, content)
@ -50,38 +47,6 @@ class PythonParser(TokenParser):
else:
self._process_import(token, content)
def _process_name(self, token, content):
if self.state is not None:
if self.nonpackage:
self.nonpackage = False
else:
if self.state == 'from':
self.append(content, truncate=True, truncate_to=0)
if self.state == 'from-2' and content != 'import':
self.append(content, truncate=True, truncate_to=0)
elif self.state == 'import':
self.append(content, truncate=True, truncate_to=0)
elif self.state == 'import-2':
self.append(content, truncate=True, truncate_to=0)
else:
self.state = None
def _process_word(self, token, content):
if self.state is not None:
if self.nonpackage:
self.nonpackage = False
else:
if self.state == 'from':
self.append(content, truncate=True, truncate_to=0)
if self.state == 'from-2' and content != 'import':
self.append(content, truncate=True, truncate_to=0)
elif self.state == 'import':
self.append(content, truncate=True, truncate_to=0)
elif self.state == 'import-2':
self.append(content, truncate=True, truncate_to=0)
else:
self.state = None
def _process_operator(self, token, content):
if self.state is not None:
if content == '.':
@ -106,15 +71,15 @@ class PythonParser(TokenParser):
def _process_import(self, token, content):
if not self.nonpackage:
if self.state == 'from':
self.append(content, truncate=True, truncate_to=0)
self.append(content, truncate=True, truncate_to=1)
self.state = 'from-2'
elif self.state == 'from-2' and content != 'import':
self.append(content, truncate=True, truncate_to=0)
self.append(content, truncate=True, truncate_to=1)
elif self.state == 'import':
self.append(content, truncate=True, truncate_to=0)
self.append(content, truncate=True, truncate_to=1)
self.state = 'import-2'
elif self.state == 'import-2':
self.append(content, truncate=True, truncate_to=0)
self.append(content, truncate=True, truncate_to=1)
else:
self.state = None
self.nonpackage = False

View File

@ -71,9 +71,7 @@ KEYWORDS = [
class LassoJavascriptParser(TokenParser):
def parse(self, tokens=[]):
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
@ -99,9 +97,7 @@ class HtmlDjangoParser(TokenParser):
current_attr = None
current_attr_value = None
def parse(self, tokens=[]):
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
def parse(self):
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies

View File

@ -22,7 +22,7 @@ FILES = {
class UnknownParser(TokenParser):
def parse(self, tokens=[]):
def parse(self):
self._process_file_name(os.path.basename(self.source_file))
return self.dependencies

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
"""
wakatime.exceptions
~~~~~~~~~~~~~~~~~~~
Custom exceptions.
:copyright: (c) 2015 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
class NotYetImplemented(Exception):
"""This method needs to be implemented."""

View File

@ -1,37 +0,0 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages.c_cpp
~~~~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from C++ code.
:copyright: (c) 2014 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from . import TokenParser
from ..compat import u
class CppParser(TokenParser):
def parse(self, tokens=[]):
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if u(token).split('.')[-1] == 'Preproc':
self._process_preproc(token, content)
else:
self._process_other(token, content)
def _process_preproc(self, token, content):
if content.strip().startswith('include ') or content.strip().startswith("include\t"):
content = content.replace('include', '', 1).strip()
self.append(content)
def _process_other(self, token, content):
pass

View File

@ -1,36 +0,0 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages.dotnet
~~~~~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from .NET code.
:copyright: (c) 2014 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from . import TokenParser
from ..compat import u
class CSharpParser(TokenParser):
def parse(self, tokens=[]):
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if u(token).split('.')[-1] == 'Namespace':
self._process_namespace(token, content)
else:
self._process_other(token, content)
def _process_namespace(self, token, content):
if content != 'import' and content != 'package' and content != 'namespace':
self.append(content, truncate=True)
def _process_other(self, token, content):
pass

View File

@ -1,36 +0,0 @@
# -*- coding: utf-8 -*-
"""
wakatime.languages.java
~~~~~~~~~~~~~~~~~~~~~~~
Parse dependencies from Java code.
:copyright: (c) 2014 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from . import TokenParser
from ..compat import u
class JavaParser(TokenParser):
def parse(self, tokens=[]):
if not tokens and not self.tokens:
self.tokens = self._extract_tokens()
for index, token, content in self.tokens:
self._process_token(token, content)
return self.dependencies
def _process_token(self, token, content):
if u(token).split('.')[-1] == 'Namespace':
self._process_namespace(token, content)
else:
self._process_other(token, content)
def _process_namespace(self, token, content):
if content != 'import' and content != 'package' and content != 'namespace':
self.append(content, truncate=True)
def _process_other(self, token, content):
pass

View File

@ -9,28 +9,31 @@
:license: BSD, see LICENSE for more details.
"""
import inspect
import logging
import os
import sys
import traceback
from .packages import simplejson as json
from .compat import u
from .packages.requests.packages import urllib3
try:
from collections import OrderedDict
except ImportError:
from collections import OrderedDict # pragma: nocover
except ImportError: # pragma: nocover
from .packages.ordereddict import OrderedDict
try:
from .packages import simplejson as json # pragma: nocover
except (ImportError, SyntaxError): # pragma: nocover
import json
class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, bytes):
obj = bytes.decode(obj)
if isinstance(obj, bytes): # pragma: nocover
obj = u(obj)
return json.dumps(obj)
try:
try: # pragma: nocover
encoded = super(CustomEncoder, self).default(obj)
except UnicodeDecodeError:
except UnicodeDecodeError: # pragma: nocover
obj = u(obj)
encoded = super(CustomEncoder, self).default(obj)
return encoded
@ -38,37 +41,46 @@ class CustomEncoder(json.JSONEncoder):
class JsonFormatter(logging.Formatter):
def setup(self, timestamp, isWrite, targetFile, version, plugin):
def setup(self, timestamp, isWrite, entity, version, plugin, verbose,
warnings=False):
self.timestamp = timestamp
self.isWrite = isWrite
self.targetFile = targetFile
self.entity = entity
self.version = version
self.plugin = plugin
self.verbose = verbose
self.warnings = warnings
def format(self, record):
def format(self, record, *args):
data = OrderedDict([
('now', self.formatTime(record, self.datefmt)),
])
try:
data['package'] = inspect.stack()[9][0].f_globals.get('__package__')
data['lineno'] = inspect.stack()[9][2]
except:
pass
data['version'] = self.version
data['plugin'] = self.plugin
data['time'] = self.timestamp
data['isWrite'] = self.isWrite
data['file'] = self.targetFile
if self.verbose:
data['caller'] = record.pathname
data['lineno'] = record.lineno
data['isWrite'] = self.isWrite
data['file'] = self.entity
if not self.isWrite:
del data['isWrite']
data['level'] = record.levelname
data['message'] = record.msg
data['message'] = record.getMessage() if self.warnings else record.msg
if not self.plugin:
del data['plugin']
if not self.isWrite:
del data['isWrite']
return CustomEncoder().encode(data)
def formatException(self, exc_info):
return sys.exec_info[2].format_exc()
def traceback_formatter(*args, **kwargs):
if 'level' in kwargs and (kwargs['level'].lower() == 'warn' or kwargs['level'].lower() == 'warning'):
logging.getLogger('WakaTime').warning(traceback.format_exc())
elif 'level' in kwargs and kwargs['level'].lower() == 'info':
logging.getLogger('WakaTime').info(traceback.format_exc())
elif 'level' in kwargs and kwargs['level'].lower() == 'debug':
logging.getLogger('WakaTime').debug(traceback.format_exc())
else:
logging.getLogger('WakaTime').error(traceback.format_exc())
def set_log_level(logger, args):
@ -79,20 +91,11 @@ def set_log_level(logger, args):
def setup_logging(args, version):
logging.captureWarnings(True)
urllib3.disable_warnings()
logger = logging.getLogger('WakaTime')
for handler in logger.handlers:
logger.removeHandler(handler)
set_log_level(logger, args)
if len(logger.handlers) > 0:
formatter = JsonFormatter(datefmt='%Y/%m/%d %H:%M:%S %z')
formatter.setup(
timestamp=args.timestamp,
isWrite=args.isWrite,
targetFile=args.targetFile,
version=version,
plugin=args.plugin,
)
logger.handlers[0].setFormatter(formatter)
return logger
logfile = args.logfile
if not logfile:
logfile = '~/.wakatime.log'
@ -101,11 +104,33 @@ def setup_logging(args, version):
formatter.setup(
timestamp=args.timestamp,
isWrite=args.isWrite,
targetFile=args.targetFile,
entity=args.entity,
version=version,
plugin=args.plugin,
verbose=args.verbose,
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logging.getLogger('py.warnings').addHandler(handler)
# add custom traceback logging method
logger.traceback = traceback_formatter
warnings_formatter = JsonFormatter(datefmt='%Y/%m/%d %H:%M:%S %z')
warnings_formatter.setup(
timestamp=args.timestamp,
isWrite=args.isWrite,
entity=args.entity,
version=version,
plugin=args.plugin,
verbose=args.verbose,
warnings=True,
)
warnings_handler = logging.FileHandler(os.path.expanduser(logfile))
warnings_handler.setFormatter(warnings_formatter)
logging.getLogger('py.warnings').addHandler(warnings_handler)
try:
logging.captureWarnings(True)
except AttributeError: # pragma: nocover
pass # Python >= 2.7 is needed to capture warnings
return logger

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
wakatime.base
wakatime.main
~~~~~~~~~~~~~
wakatime module entry point.
@ -19,27 +19,40 @@ import re
import sys
import time
import traceback
import socket
try:
import ConfigParser as configparser
except ImportError:
except ImportError: # pragma: nocover
import configparser
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'packages'))
pwd = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(pwd))
sys.path.insert(0, os.path.join(pwd, 'packages'))
from .__about__ import __version__
from .compat import u, open, is_py3
from .offlinequeue import Queue
from .constants import (
API_ERROR,
AUTH_ERROR,
CONFIG_FILE_PARSE_ERROR,
SUCCESS,
UNKNOWN_ERROR,
)
from .logger import setup_logging
from .project import find_project
from .stats import get_file_stats
from .offlinequeue import Queue
from .packages import argparse
from .packages import simplejson as json
from .packages import requests
from .packages.requests.exceptions import RequestException
from .project import get_project_info
from .session_cache import SessionCache
from .stats import get_file_stats
try:
from .packages import simplejson as json # pragma: nocover
except (ImportError, SyntaxError): # pragma: nocover
import json
try:
from .packages import tzlocal
except:
except: # pragma: nocover
from .packages import tzlocal3 as tzlocal
@ -49,49 +62,14 @@ log = logging.getLogger('WakaTime')
class FileAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = os.path.realpath(values)
try:
if os.path.isfile(values):
values = os.path.realpath(values)
except: # pragma: nocover
pass
setattr(namespace, self.dest, values)
def upgradeConfigFile(configFile):
"""For backwards-compatibility, upgrade the existing config file
to work with configparser and rename from .wakatime.conf to .wakatime.cfg.
"""
if os.path.isfile(configFile):
# if upgraded cfg file already exists, don't overwrite it
return
oldConfig = os.path.join(os.path.expanduser('~'), '.wakatime.conf')
try:
configs = {
'ignore': [],
}
with open(oldConfig, 'r', encoding='utf-8') as fh:
for line in fh.readlines():
line = line.split('=', 1)
if len(line) == 2 and line[0].strip() and line[1].strip():
if line[0].strip() == 'ignore':
configs['ignore'].append(line[1].strip())
else:
configs[line[0].strip()] = line[1].strip()
with open(configFile, 'w', encoding='utf-8') as fh:
fh.write("[settings]\n")
for name, value in configs.items():
if isinstance(value, list):
fh.write("%s=\n" % name)
for item in value:
fh.write(" %s\n" % item)
else:
fh.write("%s = %s\n" % (name, value))
os.remove(oldConfig)
except IOError:
pass
def parseConfigFile(configFile=None):
"""Returns a configparser.SafeConfigParser instance with configs
read from the config file. Default location of the config file is
@ -101,8 +79,6 @@ def parseConfigFile(configFile=None):
if not configFile:
configFile = os.path.join(os.path.expanduser('~'), '.wakatime.cfg')
upgradeConfigFile(configFile)
configs = configparser.SafeConfigParser()
try:
with open(configFile, 'r', encoding='utf-8') as fh:
@ -116,23 +92,21 @@ def parseConfigFile(configFile=None):
return configs
def parseArguments(argv):
def parseArguments():
"""Parse command line arguments and configs from ~/.wakatime.cfg.
Command line arguments take precedence over config file settings.
Returns instances of ArgumentParser and SafeConfigParser.
"""
try:
sys.argv
except AttributeError:
sys.argv = argv
# define supported command line arguments
parser = argparse.ArgumentParser(
description='Common interface for the WakaTime api.')
parser.add_argument('--file', dest='targetFile', metavar='file',
action=FileAction, required=True,
help='absolute path to file for current heartbeat')
parser.add_argument('--entity', dest='entity', metavar='FILE',
action=FileAction,
help='absolute path to file for the heartbeat; can also be a '+
'url, domain, or app when --entitytype is not file')
parser.add_argument('--file', dest='file', action=FileAction,
help=argparse.SUPPRESS)
parser.add_argument('--key', dest='key',
help='your wakatime api key; uses api_key from '+
'~/.wakatime.conf by default')
@ -147,14 +121,23 @@ def parseArguments(argv):
type=float,
help='optional floating-point unix epoch timestamp; '+
'uses current time by default')
parser.add_argument('--notfile', dest='notfile', action='store_true',
help='when set, will accept any value for the file. for example, '+
'a domain name or other item you want to log time towards.')
parser.add_argument('--lineno', dest='lineno',
help='optional line number; current line being edited')
parser.add_argument('--cursorpos', dest='cursorpos',
help='optional cursor position in the current file')
parser.add_argument('--entitytype', dest='entity_type',
help='entity type for this heartbeat. can be one of "file", '+
'"url", "domain", or "app"; defaults to file.')
parser.add_argument('--proxy', dest='proxy',
help='optional https proxy url; for example: '+
'https://user:pass@localhost:8080')
parser.add_argument('--project', dest='project_name',
help='optional project name; auto-discovered project takes priority')
'https://user:pass@localhost:8080')
parser.add_argument('--project', dest='project',
help='optional project name')
parser.add_argument('--alternate-project', dest='alternate_project',
help='optional alternate project name; auto-discovered project '+
'takes priority')
parser.add_argument('--hostname', dest='hostname', help='hostname of '+
'current machine.')
parser.add_argument('--disableoffline', dest='offline',
action='store_false',
help='disables offline time logging instead of queuing logged time')
@ -174,6 +157,8 @@ def parseArguments(argv):
help='defaults to ~/.wakatime.log')
parser.add_argument('--apiurl', dest='api_url',
help='heartbeats api url; for debugging with a local server')
parser.add_argument('--timeout', dest='timeout', type=int,
help='number of seconds to wait when sending heartbeats to api')
parser.add_argument('--config', dest='config',
help='defaults to ~/.wakatime.conf')
parser.add_argument('--verbose', dest='verbose', action='store_true',
@ -181,7 +166,7 @@ def parseArguments(argv):
parser.add_argument('--version', action='version', version=__version__)
# parse command line arguments
args = parser.parse_args(args=argv[1:])
args = parser.parse_args()
# use current unix epoch timestamp by default
if not args.timestamp:
@ -203,6 +188,13 @@ def parseArguments(argv):
args.key = default_key
else:
parser.error('Missing api key')
if not args.entity_type:
args.entity_type = 'file'
if not args.entity:
if args.file:
args.entity = args.file
else:
parser.error('argument --entity is required')
if not args.exclude:
args.exclude = []
if configs.has_option('settings', 'ignore'):
@ -210,14 +202,14 @@ def parseArguments(argv):
for pattern in configs.get('settings', 'ignore').split("\n"):
if pattern.strip() != '':
args.exclude.append(pattern)
except TypeError:
except TypeError: # pragma: nocover
pass
if configs.has_option('settings', 'exclude'):
try:
for pattern in configs.get('settings', 'exclude').split("\n"):
if pattern.strip() != '':
args.exclude.append(pattern)
except TypeError:
except TypeError: # pragma: nocover
pass
if not args.include:
args.include = []
@ -226,7 +218,7 @@ def parseArguments(argv):
for pattern in configs.get('settings', 'include').split("\n"):
if pattern.strip() != '':
args.include.append(pattern)
except TypeError:
except TypeError: # pragma: nocover
pass
if args.offline and configs.has_option('settings', 'offline'):
args.offline = configs.getboolean('settings', 'offline')
@ -242,37 +234,42 @@ def parseArguments(argv):
args.logfile = configs.get('settings', 'logfile')
if not args.api_url and configs.has_option('settings', 'api_url'):
args.api_url = configs.get('settings', 'api_url')
if not args.timeout and configs.has_option('settings', 'timeout'):
try:
args.timeout = int(configs.get('settings', 'timeout'))
except ValueError:
print(traceback.format_exc())
return args, configs
def should_exclude(fileName, include, exclude):
if fileName is not None and fileName.strip() != '':
def should_exclude(entity, include, exclude):
if entity is not None and entity.strip() != '':
try:
for pattern in include:
try:
compiled = re.compile(pattern, re.IGNORECASE)
if compiled.search(fileName):
if compiled.search(entity):
return False
except re.error as ex:
log.warning(u('Regex error ({msg}) for include pattern: {pattern}').format(
msg=u(ex),
pattern=u(pattern),
))
except TypeError:
except TypeError: # pragma: nocover
pass
try:
for pattern in exclude:
try:
compiled = re.compile(pattern, re.IGNORECASE)
if compiled.search(fileName):
if compiled.search(entity):
return pattern
except re.error as ex:
log.warning(u('Regex error ({msg}) for exclude pattern: {pattern}').format(
msg=u(ex),
pattern=u(pattern),
))
except TypeError:
except TypeError: # pragma: nocover
pass
return False
@ -297,31 +294,39 @@ def get_user_agent(plugin):
return user_agent
def send_heartbeat(project=None, branch=None, stats={}, key=None, targetFile=None,
timestamp=None, isWrite=None, plugin=None, offline=None, notfile=False,
hidefilenames=None, proxy=None, api_url=None, **kwargs):
def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None,
entity=None, timestamp=None, isWrite=None, plugin=None,
offline=None, entity_type='file', hidefilenames=None,
proxy=None, api_url=None, timeout=None, **kwargs):
"""Sends heartbeat as POST request to WakaTime api server.
Returns `SUCCESS` when heartbeat was sent, otherwise returns an
error code constant.
"""
if not api_url:
api_url = 'https://wakatime.com/api/v1/heartbeats'
api_url = 'https://api.wakatime.com/api/v1/heartbeats'
if not timeout:
timeout = 30
log.debug('Sending heartbeat to api at %s' % api_url)
data = {
'time': timestamp,
'file': targetFile,
'entity': entity,
'type': entity_type,
}
if hidefilenames and targetFile is not None and not notfile:
data['file'] = data['file'].rsplit('/', 1)[-1].rsplit('\\', 1)[-1]
if len(data['file'].strip('.').split('.', 1)) > 1:
data['file'] = u('HIDDEN.{ext}').format(ext=u(data['file'].strip('.').rsplit('.', 1)[-1]))
else:
data['file'] = u('HIDDEN')
if hidefilenames and entity is not None and entity_type == 'file':
extension = u(os.path.splitext(data['entity'])[1])
data['entity'] = u('HIDDEN{0}').format(extension)
if stats.get('lines'):
data['lines'] = stats['lines']
if stats.get('language'):
data['language'] = stats['language']
if stats.get('dependencies'):
data['dependencies'] = stats['dependencies']
if stats.get('lineno'):
data['lineno'] = stats['lineno']
if stats.get('cursorpos'):
data['cursorpos'] = stats['cursorpos']
if isWrite:
data['is_write'] = isWrite
if project:
@ -340,6 +345,8 @@ def send_heartbeat(project=None, branch=None, stats={}, key=None, targetFile=Non
'Accept': 'application/json',
'Authorization': auth,
}
if hostname:
headers['X-Machine-Name'] = u(hostname).encode('utf-8')
proxies = {}
if proxy:
proxies['https'] = proxy
@ -350,13 +357,16 @@ def send_heartbeat(project=None, branch=None, stats={}, key=None, targetFile=Non
except:
tz = None
if tz:
headers['TimeZone'] = u(tz.zone)
headers['TimeZone'] = u(tz.zone).encode('utf-8')
session_cache = SessionCache()
session = session_cache.get()
# log time to api
response = None
try:
response = requests.post(api_url, data=request_body, headers=headers,
proxies=proxies)
response = session.post(api_url, data=request_body, headers=headers,
proxies=proxies, timeout=timeout)
except RequestException:
exception_data = {
sys.exc_info()[0].__name__: u(sys.exc_info()[1]),
@ -371,102 +381,123 @@ def send_heartbeat(project=None, branch=None, stats={}, key=None, targetFile=Non
else:
log.error(exception_data)
else:
response_code = response.status_code if response is not None else None
response_content = response.text if response is not None else None
if response_code == 201:
code = response.status_code if response is not None else None
content = response.text if response is not None else None
if code == requests.codes.created or code == requests.codes.accepted:
log.debug({
'response_code': response_code,
'response_code': code,
})
return True
session_cache.save(session)
return SUCCESS
if offline:
if response_code != 400:
if code != 400:
queue = Queue()
queue.push(data, json.dumps(stats), plugin)
if response_code == 401:
if code == 401:
log.error({
'response_code': response_code,
'response_content': response_content,
'response_code': code,
'response_content': content,
})
session_cache.delete()
return AUTH_ERROR
elif log.isEnabledFor(logging.DEBUG):
log.warn({
'response_code': response_code,
'response_content': response_content,
'response_code': code,
'response_content': content,
})
else:
log.error({
'response_code': response_code,
'response_content': response_content,
'response_code': code,
'response_content': content,
})
else:
log.error({
'response_code': response_code,
'response_content': response_content,
'response_code': code,
'response_content': content,
})
return False
session_cache.delete()
return API_ERROR
def main(argv=None):
if not argv:
argv = sys.argv
def sync_offline_heartbeats(args, hostname):
"""Sends all heartbeats which were cached in the offline Queue."""
args, configs = parseArguments(argv)
queue = Queue()
while True:
heartbeat = queue.pop()
if heartbeat is None:
break
status = send_heartbeat(
project=heartbeat['project'],
entity=heartbeat['entity'],
timestamp=heartbeat['time'],
branch=heartbeat['branch'],
hostname=hostname,
stats=json.loads(heartbeat['stats']),
key=args.key,
isWrite=heartbeat['is_write'],
plugin=heartbeat['plugin'],
offline=args.offline,
hidefilenames=args.hidefilenames,
entity_type=heartbeat['type'],
proxy=args.proxy,
api_url=args.api_url,
timeout=args.timeout,
)
if status != SUCCESS:
if status == AUTH_ERROR:
return AUTH_ERROR
break
return SUCCESS
def execute(argv=None):
if argv:
sys.argv = ['wakatime'] + argv
args, configs = parseArguments()
if configs is None:
return 103 # config file parsing error
return CONFIG_FILE_PARSE_ERROR
setup_logging(args, __version__)
exclude = should_exclude(args.targetFile, args.include, args.exclude)
if exclude is not False:
log.debug(u('File not logged because matches exclude pattern: {pattern}').format(
pattern=u(exclude),
))
return 0
try:
exclude = should_exclude(args.entity, args.include, args.exclude)
if exclude is not False:
log.debug(u('Skipping because matches exclude pattern: {pattern}').format(
pattern=u(exclude),
))
return SUCCESS
if os.path.isfile(args.targetFile) or args.notfile:
if args.entity_type != 'file' or os.path.isfile(args.entity):
stats = get_file_stats(args.targetFile, notfile=args.notfile)
stats = get_file_stats(args.entity,
entity_type=args.entity_type,
lineno=args.lineno,
cursorpos=args.cursorpos)
project = None
if not args.notfile:
project = find_project(args.targetFile, configs=configs)
branch = None
project_name = args.project_name
if project:
branch = project.branch()
project_name = project.name()
project = args.project or args.alternate_project
branch = None
if args.entity_type == 'file':
project, branch = get_project_info(configs, args)
if send_heartbeat(
project=project_name,
branch=branch,
stats=stats,
**vars(args)
):
queue = Queue()
while True:
heartbeat = queue.pop()
if heartbeat is None:
break
sent = send_heartbeat(
project=heartbeat['project'],
targetFile=heartbeat['file'],
timestamp=heartbeat['time'],
branch=heartbeat['branch'],
stats=json.loads(heartbeat['stats']),
key=args.key,
isWrite=heartbeat['is_write'],
plugin=heartbeat['plugin'],
offline=args.offline,
hidefilenames=args.hidefilenames,
notfile=args.notfile,
proxy=args.proxy,
api_url=args.api_url,
)
if not sent:
break
return 0 # success
kwargs = vars(args)
kwargs['project'] = project
kwargs['branch'] = branch
kwargs['stats'] = stats
hostname = args.hostname or socket.gethostname()
kwargs['hostname'] = hostname
kwargs['timeout'] = args.timeout
return 102 # api error
status = send_heartbeat(**kwargs)
if status == SUCCESS:
return sync_offline_heartbeats(args, hostname)
else:
return status
else:
log.debug('File does not exist; ignoring this heartbeat.')
return 0
else:
log.debug('File does not exist; ignoring this heartbeat.')
return SUCCESS
except:
log.traceback()
return UNKNOWN_ERROR

View File

@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
"""
wakatime.queue
~~~~~~~~~~~~~~
wakatime.offlinequeue
~~~~~~~~~~~~~~~~~~~~~
Queue for offline time logging.
http://wakatime.com
Queue for saving heartbeats while offline.
:copyright: (c) 2014 Alan Hamlett.
:license: BSD, see LICENSE for more details.
@ -19,21 +18,28 @@ from time import sleep
try:
import sqlite3
HAS_SQL = True
except ImportError:
except ImportError: # pragma: nocover
HAS_SQL = False
from .compat import u
log = logging.getLogger('WakaTime')
class Queue(object):
DB_FILE = os.path.join(os.path.expanduser('~'), '.wakatime.db')
db_file = os.path.join(os.path.expanduser('~'), '.wakatime.db')
table_name = 'heartbeat_1'
def get_db_file(self):
return self.db_file
def connect(self):
conn = sqlite3.connect(self.DB_FILE)
conn = sqlite3.connect(self.get_db_file())
c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS heartbeat (
file text,
c.execute('''CREATE TABLE IF NOT EXISTS {0} (
entity text,
type text,
time real,
project text,
branch text,
@ -41,34 +47,33 @@ class Queue(object):
stats text,
misc text,
plugin text)
''')
'''.format(self.table_name))
return (conn, c)
def push(self, data, stats, plugin, misc=None):
if not HAS_SQL:
if not HAS_SQL: # pragma: nocover
return
try:
conn, c = self.connect()
heartbeat = {
'file': data.get('file'),
'entity': u(data.get('entity')),
'type': u(data.get('type')),
'time': data.get('time'),
'project': data.get('project'),
'branch': data.get('branch'),
'project': u(data.get('project')),
'branch': u(data.get('branch')),
'is_write': 1 if data.get('is_write') else 0,
'stats': stats,
'misc': misc,
'plugin': plugin,
'stats': u(stats),
'misc': u(misc),
'plugin': u(plugin),
}
c.execute('INSERT INTO heartbeat VALUES (:file,:time,:project,:branch,:is_write,:stats,:misc,:plugin)', heartbeat)
c.execute('INSERT INTO {0} VALUES (:entity,:type,:time,:project,:branch,:is_write,:stats,:misc,:plugin)'.format(self.table_name), heartbeat)
conn.commit()
conn.close()
except sqlite3.Error:
log.error(traceback.format_exc())
def pop(self):
if not HAS_SQL:
if not HAS_SQL: # pragma: nocover
return None
tries = 3
wait = 0.1
@ -82,42 +87,43 @@ class Queue(object):
while loop and tries > -1:
try:
c.execute('BEGIN IMMEDIATE')
c.execute('SELECT * FROM heartbeat LIMIT 1')
c.execute('SELECT * FROM {0} LIMIT 1'.format(self.table_name))
row = c.fetchone()
if row is not None:
values = []
clauses = []
index = 0
for row_name in ['file', 'time', 'project', 'branch', 'is_write']:
for row_name in ['entity', 'type', 'time', 'project', 'branch', 'is_write']:
if row[index] is not None:
clauses.append('{0}=?'.format(row_name))
values.append(row[index])
else:
else: # pragma: nocover
clauses.append('{0} IS NULL'.format(row_name))
index += 1
if len(values) > 0:
c.execute('DELETE FROM heartbeat WHERE {0}'.format(' AND '.join(clauses)), values)
else:
c.execute('DELETE FROM heartbeat WHERE {0}'.format(' AND '.join(clauses)))
c.execute('DELETE FROM {0} WHERE {1}'.format(self.table_name, ' AND '.join(clauses)), values)
else: # pragma: nocover
c.execute('DELETE FROM {0} WHERE {1}'.format(self.table_name, ' AND '.join(clauses)))
conn.commit()
if row is not None:
heartbeat = {
'file': row[0],
'time': row[1],
'project': row[2],
'branch': row[3],
'is_write': True if row[4] is 1 else False,
'stats': row[5],
'misc': row[6],
'plugin': row[7],
'entity': row[0],
'type': row[1],
'time': row[2],
'project': row[3],
'branch': row[4],
'is_write': True if row[5] is 1 else False,
'stats': row[6],
'misc': row[7],
'plugin': row[8],
}
loop = False
except sqlite3.Error:
except sqlite3.Error: # pragma: nocover
log.debug(traceback.format_exc())
sleep(wait)
tries -= 1
try:
conn.close()
except sqlite3.Error:
except sqlite3.Error: # pragma: nocover
log.debug(traceback.format_exc())
return heartbeat

View File

@ -61,7 +61,12 @@ considered public as object names -- the API of the formatter objects is
still considered an implementation detail.)
"""
__version__ = '1.2.1'
__version__ = '1.3.0' # we use our own version number independant of the
# one in stdlib and we release this on pypi.
__external_lib__ = True # to make sure the tests really test THIS lib,
# not the builtin one in Python stdlib
__all__ = [
'ArgumentParser',
'ArgumentError',
@ -1045,9 +1050,13 @@ class _SubParsersAction(Action):
class _ChoicesPseudoAction(Action):
def __init__(self, name, help):
def __init__(self, name, aliases, help):
metavar = dest = name
if aliases:
metavar += ' (%s)' % ', '.join(aliases)
sup = super(_SubParsersAction._ChoicesPseudoAction, self)
sup.__init__(option_strings=[], dest=name, help=help)
sup.__init__(option_strings=[], dest=dest, help=help,
metavar=metavar)
def __init__(self,
option_strings,
@ -1075,15 +1084,22 @@ class _SubParsersAction(Action):
if kwargs.get('prog') is None:
kwargs['prog'] = '%s %s' % (self._prog_prefix, name)
aliases = kwargs.pop('aliases', ())
# create a pseudo-action to hold the choice help
if 'help' in kwargs:
help = kwargs.pop('help')
choice_action = self._ChoicesPseudoAction(name, help)
choice_action = self._ChoicesPseudoAction(name, aliases, help)
self._choices_actions.append(choice_action)
# create the parser and add it to the map
parser = self._parser_class(**kwargs)
self._name_parser_map[name] = parser
# make parser available under aliases also
for alias in aliases:
self._name_parser_map[alias] = parser
return parser
def _get_subactions(self):

View File

@ -6,7 +6,7 @@
# /
"""
requests HTTP library
Requests HTTP library
~~~~~~~~~~~~~~~~~~~~~
Requests is an HTTP library, written in Python, for human beings. Basic GET
@ -42,11 +42,11 @@ is at <http://python-requests.org>.
"""
__title__ = 'requests'
__version__ = '2.6.0'
__build__ = 0x020503
__version__ = '2.9.1'
__build__ = 0x020901
__author__ = 'Kenneth Reitz'
__license__ = 'Apache 2.0'
__copyright__ = 'Copyright 2015 Kenneth Reitz'
__copyright__ = 'Copyright 2016 Kenneth Reitz'
# Attempt to enable urllib3's SNI support, if possible
try:
@ -62,7 +62,8 @@ from .sessions import session, Session
from .status_codes import codes
from .exceptions import (
RequestException, Timeout, URLRequired,
TooManyRedirects, HTTPError, ConnectionError
TooManyRedirects, HTTPError, ConnectionError,
FileModeWarning,
)
# Set default logging handler to avoid "No handler found" warnings.
@ -75,3 +76,8 @@ except ImportError:
pass
logging.getLogger(__name__).addHandler(NullHandler())
import warnings
# FileModeWarnings go off per the default.
warnings.simplefilter('default', FileModeWarning, append=True)

View File

@ -8,6 +8,7 @@ This module contains the transport adapters that Requests uses to define
and maintain connections.
"""
import os.path
import socket
from .models import Response
@ -17,11 +18,14 @@ from .packages.urllib3.util import Timeout as TimeoutSauce
from .packages.urllib3.util.retry import Retry
from .compat import urlparse, basestring
from .utils import (DEFAULT_CA_BUNDLE_PATH, get_encoding_from_headers,
prepend_scheme_if_needed, get_auth_from_url, urldefragauth)
prepend_scheme_if_needed, get_auth_from_url, urldefragauth,
select_proxy)
from .structures import CaseInsensitiveDict
from .packages.urllib3.exceptions import ClosedPoolError
from .packages.urllib3.exceptions import ConnectTimeoutError
from .packages.urllib3.exceptions import HTTPError as _HTTPError
from .packages.urllib3.exceptions import MaxRetryError
from .packages.urllib3.exceptions import NewConnectionError
from .packages.urllib3.exceptions import ProxyError as _ProxyError
from .packages.urllib3.exceptions import ProtocolError
from .packages.urllib3.exceptions import ReadTimeoutError
@ -35,6 +39,7 @@ from .auth import _basic_auth_str
DEFAULT_POOLBLOCK = False
DEFAULT_POOLSIZE = 10
DEFAULT_RETRIES = 0
DEFAULT_POOL_TIMEOUT = None
class BaseAdapter(object):
@ -60,7 +65,7 @@ class HTTPAdapter(BaseAdapter):
:param pool_connections: The number of urllib3 connection pools to cache.
:param pool_maxsize: The maximum number of connections to save in the pool.
:param int max_retries: The maximum number of retries each connection
:param max_retries: The maximum number of retries each connection
should attempt. Note, this applies only to failed DNS lookups, socket
connections and connection timeouts, never to requests where data has
made it to the server. By default, Requests does not retry failed
@ -103,7 +108,7 @@ class HTTPAdapter(BaseAdapter):
def __setstate__(self, state):
# Can't handle by adding 'proxy_manager' to self.__attrs__ because
# because self.poolmanager uses a lambda function, which isn't pickleable.
# self.poolmanager uses a lambda function, which isn't pickleable.
self.proxy_manager = {}
self.config = {}
@ -181,10 +186,15 @@ class HTTPAdapter(BaseAdapter):
raise Exception("Could not find a suitable SSL CA certificate bundle.")
conn.cert_reqs = 'CERT_REQUIRED'
conn.ca_certs = cert_loc
if not os.path.isdir(cert_loc):
conn.ca_certs = cert_loc
else:
conn.ca_cert_dir = cert_loc
else:
conn.cert_reqs = 'CERT_NONE'
conn.ca_certs = None
conn.ca_cert_dir = None
if cert:
if not isinstance(cert, basestring):
@ -237,8 +247,7 @@ class HTTPAdapter(BaseAdapter):
:param url: The URL to connect to.
:param proxies: (optional) A Requests-style dictionary of proxies used on this request.
"""
proxies = proxies or {}
proxy = proxies.get(urlparse(url.lower()).scheme)
proxy = select_proxy(url, proxies)
if proxy:
proxy = prepend_scheme_if_needed(proxy, 'http')
@ -271,12 +280,10 @@ class HTTPAdapter(BaseAdapter):
:class:`HTTPAdapter <requests.adapters.HTTPAdapter>`.
:param request: The :class:`PreparedRequest <PreparedRequest>` being sent.
:param proxies: A dictionary of schemes to proxy URLs.
:param proxies: A dictionary of schemes or schemes and hosts to proxy URLs.
"""
proxies = proxies or {}
proxy = select_proxy(request.url, proxies)
scheme = urlparse(request.url).scheme
proxy = proxies.get(scheme)
if proxy and scheme != 'https':
url = urldefragauth(request.url)
else:
@ -309,7 +316,6 @@ class HTTPAdapter(BaseAdapter):
:class:`HTTPAdapter <requests.adapters.HTTPAdapter>`.
:param proxies: The url of the proxy being used for this request.
:param kwargs: Optional additional keyword arguments.
"""
headers = {}
username, password = get_auth_from_url(proxy)
@ -326,8 +332,8 @@ class HTTPAdapter(BaseAdapter):
:param request: The :class:`PreparedRequest <PreparedRequest>` being sent.
:param stream: (optional) Whether to stream the request content.
:param timeout: (optional) How long to wait for the server to send
data before giving up, as a float, or a (`connect timeout, read
timeout <user/advanced.html#timeouts>`_) tuple.
data before giving up, as a float, or a :ref:`(connect timeout,
read timeout) <timeouts>` tuple.
:type timeout: float or tuple
:param verify: (optional) Whether to verify SSL certificates.
:param cert: (optional) Any user-provided SSL certificate to be trusted.
@ -375,7 +381,7 @@ class HTTPAdapter(BaseAdapter):
if hasattr(conn, 'proxy_pool'):
conn = conn.proxy_pool
low_conn = conn._get_conn(timeout=timeout)
low_conn = conn._get_conn(timeout=DEFAULT_POOL_TIMEOUT)
try:
low_conn.putrequest(request.method,
@ -394,7 +400,15 @@ class HTTPAdapter(BaseAdapter):
low_conn.send(b'\r\n')
low_conn.send(b'0\r\n\r\n')
r = low_conn.getresponse()
# Receive the response from the server
try:
# For Python 2.7+ versions, use buffering of HTTP
# responses
r = low_conn.getresponse(buffering=True)
except TypeError:
# For compatibility with Python 2.6 versions and back
r = low_conn.getresponse()
resp = HTTPResponse.from_httplib(
r,
pool=conn,
@ -407,22 +421,24 @@ class HTTPAdapter(BaseAdapter):
# Then, reraise so that we can handle the actual exception.
low_conn.close()
raise
else:
# All is well, return the connection to the pool.
conn._put_conn(low_conn)
except (ProtocolError, socket.error) as err:
raise ConnectionError(err, request=request)
except MaxRetryError as e:
if isinstance(e.reason, ConnectTimeoutError):
raise ConnectTimeout(e, request=request)
# TODO: Remove this in 3.0.0: see #2811
if not isinstance(e.reason, NewConnectionError):
raise ConnectTimeout(e, request=request)
if isinstance(e.reason, ResponseError):
raise RetryError(e, request=request)
raise ConnectionError(e, request=request)
except ClosedPoolError as e:
raise ConnectionError(e, request=request)
except _ProxyError as e:
raise ProxyError(e)

View File

@ -27,13 +27,13 @@ def request(method, url, **kwargs):
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': ('filename', fileobj)}``) for multipart encoding upload.
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
:param timeout: (optional) How long to wait for the server to send data
before giving up, as a float, or a (`connect timeout, read timeout
<user/advanced.html#timeouts>`_) tuple.
before giving up, as a float, or a :ref:`(connect timeout, read
timeout) <timeouts>` tuple.
:type timeout: float or tuple
:param allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE redirect following is allowed.
:type allow_redirects: bool
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
:param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided.
:param verify: (optional) whether the SSL cert will be verified. A CA_BUNDLE path can also be provided. Defaults to ``True``.
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
:param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair.
:return: :class:`Response <Response>` object
@ -46,26 +46,25 @@ def request(method, url, **kwargs):
<Response [200]>
"""
session = sessions.Session()
response = session.request(method=method, url=url, **kwargs)
# By explicitly closing the session, we avoid leaving sockets open which
# can trigger a ResourceWarning in some cases, and look like a memory leak
# in others.
session.close()
return response
# By using the 'with' statement we are sure the session is closed, thus we
# avoid leaving sockets open which can trigger a ResourceWarning in some
# cases, and look like a memory leak in others.
with sessions.Session() as session:
return session.request(method=method, url=url, **kwargs)
def get(url, **kwargs):
def get(url, params=None, **kwargs):
"""Sends a GET request.
:param url: URL for the new :class:`Request` object.
:param params: (optional) Dictionary or bytes to be sent in the query string for the :class:`Request`.
:param \*\*kwargs: Optional arguments that ``request`` takes.
:return: :class:`Response <Response>` object
:rtype: requests.Response
"""
kwargs.setdefault('allow_redirects', True)
return request('get', url, **kwargs)
return request('get', url, params=params, **kwargs)
def options(url, **kwargs):

View File

@ -11,6 +11,7 @@ import os
import re
import time
import hashlib
import threading
from base64 import b64encode
@ -46,6 +47,15 @@ class HTTPBasicAuth(AuthBase):
self.username = username
self.password = password
def __eq__(self, other):
return all([
self.username == getattr(other, 'username', None),
self.password == getattr(other, 'password', None)
])
def __ne__(self, other):
return not self == other
def __call__(self, r):
r.headers['Authorization'] = _basic_auth_str(self.username, self.password)
return r
@ -63,19 +73,26 @@ class HTTPDigestAuth(AuthBase):
def __init__(self, username, password):
self.username = username
self.password = password
self.last_nonce = ''
self.nonce_count = 0
self.chal = {}
self.pos = None
self.num_401_calls = 1
# Keep state in per-thread local storage
self._thread_local = threading.local()
def init_per_thread_state(self):
# Ensure state is initialized just once per-thread
if not hasattr(self._thread_local, 'init'):
self._thread_local.init = True
self._thread_local.last_nonce = ''
self._thread_local.nonce_count = 0
self._thread_local.chal = {}
self._thread_local.pos = None
self._thread_local.num_401_calls = None
def build_digest_header(self, method, url):
realm = self.chal['realm']
nonce = self.chal['nonce']
qop = self.chal.get('qop')
algorithm = self.chal.get('algorithm')
opaque = self.chal.get('opaque')
realm = self._thread_local.chal['realm']
nonce = self._thread_local.chal['nonce']
qop = self._thread_local.chal.get('qop')
algorithm = self._thread_local.chal.get('algorithm')
opaque = self._thread_local.chal.get('opaque')
if algorithm is None:
_algorithm = 'MD5'
@ -103,7 +120,8 @@ class HTTPDigestAuth(AuthBase):
# XXX not implemented yet
entdig = None
p_parsed = urlparse(url)
path = p_parsed.path
#: path is request-uri defined in RFC 2616 which should not be empty
path = p_parsed.path or "/"
if p_parsed.query:
path += '?' + p_parsed.query
@ -113,12 +131,12 @@ class HTTPDigestAuth(AuthBase):
HA1 = hash_utf8(A1)
HA2 = hash_utf8(A2)
if nonce == self.last_nonce:
self.nonce_count += 1
if nonce == self._thread_local.last_nonce:
self._thread_local.nonce_count += 1
else:
self.nonce_count = 1
ncvalue = '%08x' % self.nonce_count
s = str(self.nonce_count).encode('utf-8')
self._thread_local.nonce_count = 1
ncvalue = '%08x' % self._thread_local.nonce_count
s = str(self._thread_local.nonce_count).encode('utf-8')
s += nonce.encode('utf-8')
s += time.ctime().encode('utf-8')
s += os.urandom(8)
@ -127,7 +145,7 @@ class HTTPDigestAuth(AuthBase):
if _algorithm == 'MD5-SESS':
HA1 = hash_utf8('%s:%s:%s' % (HA1, nonce, cnonce))
if qop is None:
if not qop:
respdig = KD(HA1, "%s:%s" % (nonce, HA2))
elif qop == 'auth' or 'auth' in qop.split(','):
noncebit = "%s:%s:%s:%s:%s" % (
@ -138,7 +156,7 @@ class HTTPDigestAuth(AuthBase):
# XXX handle auth-int.
return None
self.last_nonce = nonce
self._thread_local.last_nonce = nonce
# XXX should the partial digests be encoded too?
base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \
@ -157,28 +175,27 @@ class HTTPDigestAuth(AuthBase):
def handle_redirect(self, r, **kwargs):
"""Reset num_401_calls counter on redirects."""
if r.is_redirect:
self.num_401_calls = 1
self._thread_local.num_401_calls = 1
def handle_401(self, r, **kwargs):
"""Takes the given response and tries digest-auth, if needed."""
if self.pos is not None:
if self._thread_local.pos is not None:
# Rewind the file position indicator of the body to where
# it was to resend the request.
r.request.body.seek(self.pos)
num_401_calls = getattr(self, 'num_401_calls', 1)
r.request.body.seek(self._thread_local.pos)
s_auth = r.headers.get('www-authenticate', '')
if 'digest' in s_auth.lower() and num_401_calls < 2:
if 'digest' in s_auth.lower() and self._thread_local.num_401_calls < 2:
self.num_401_calls += 1
self._thread_local.num_401_calls += 1
pat = re.compile(r'digest ', flags=re.IGNORECASE)
self.chal = parse_dict_header(pat.sub('', s_auth, count=1))
self._thread_local.chal = parse_dict_header(pat.sub('', s_auth, count=1))
# Consume content and release the original connection
# to allow our new request to reuse the same one.
r.content
r.raw.release_conn()
r.close()
prep = r.request.copy()
extract_cookies_to_jar(prep._cookies, r.request, r.raw)
prep.prepare_cookies(prep._cookies)
@ -191,21 +208,34 @@ class HTTPDigestAuth(AuthBase):
return _r
self.num_401_calls = 1
self._thread_local.num_401_calls = 1
return r
def __call__(self, r):
# Initialize per-thread state, if needed
self.init_per_thread_state()
# If we have a saved nonce, skip the 401
if self.last_nonce:
if self._thread_local.last_nonce:
r.headers['Authorization'] = self.build_digest_header(r.method, r.url)
try:
self.pos = r.body.tell()
self._thread_local.pos = r.body.tell()
except AttributeError:
# In the case of HTTPDigestAuth being reused and the body of
# the previous request was a file-like object, pos has the
# file position of the previous body. Ensure it's set to
# None.
self.pos = None
self._thread_local.pos = None
r.register_hook('response', self.handle_401)
r.register_hook('response', self.handle_redirect)
self._thread_local.num_401_calls = 1
return r
def __eq__(self, other):
return all([
self.username == getattr(other, 'username', None),
self.password == getattr(other, 'password', None)
])
def __ne__(self, other):
return not self == other

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,9 @@ Compatibility code to be able to use `cookielib.CookieJar` with requests.
requests.utils imports from here, so be careful with imports.
"""
import copy
import time
import calendar
import collections
from .compat import cookielib, urlparse, urlunparse, Morsel
@ -142,10 +144,13 @@ def remove_cookie_by_name(cookiejar, name, domain=None, path=None):
"""
clearables = []
for cookie in cookiejar:
if cookie.name == name:
if domain is None or domain == cookie.domain:
if path is None or path == cookie.path:
clearables.append((cookie.domain, cookie.path, cookie.name))
if cookie.name != name:
continue
if domain is not None and domain != cookie.domain:
continue
if path is not None and path != cookie.path:
continue
clearables.append((cookie.domain, cookie.path, cookie.name))
for domain, path, name in clearables:
cookiejar.clear(domain, path, name)
@ -272,6 +277,12 @@ class RequestsCookieJar(cookielib.CookieJar, collections.MutableMapping):
dictionary[cookie.name] = cookie.value
return dictionary
def __contains__(self, name):
try:
return super(RequestsCookieJar, self).__contains__(name)
except CookieConflictError:
return True
def __getitem__(self, name):
"""Dict-like __getitem__() for compatibility with client code. Throws
exception if there are more than one cookie with name. In that case,
@ -302,7 +313,7 @@ class RequestsCookieJar(cookielib.CookieJar, collections.MutableMapping):
"""Updates this jar with cookies from another CookieJar or dict-like"""
if isinstance(other, cookielib.CookieJar):
for cookie in other:
self.set_cookie(cookie)
self.set_cookie(copy.copy(cookie))
else:
super(RequestsCookieJar, self).update(other)
@ -359,6 +370,21 @@ class RequestsCookieJar(cookielib.CookieJar, collections.MutableMapping):
return new_cj
def _copy_cookie_jar(jar):
if jar is None:
return None
if hasattr(jar, 'copy'):
# We're dealing with an instance of RequestsCookieJar
return jar.copy()
# We're dealing with a generic CookieJar instance
new_jar = copy.copy(jar)
new_jar.clear()
for cookie in jar:
new_jar.set_cookie(copy.copy(cookie))
return new_jar
def create_cookie(name, value, **kwargs):
"""Make a cookie from underspecified parameters.
@ -399,11 +425,15 @@ def morsel_to_cookie(morsel):
expires = None
if morsel['max-age']:
expires = time.time() + morsel['max-age']
try:
expires = int(time.time() + int(morsel['max-age']))
except ValueError:
raise TypeError('max-age: %s must be integer' % morsel['max-age'])
elif morsel['expires']:
time_template = '%a, %d-%b-%Y %H:%M:%S GMT'
expires = time.mktime(
time.strptime(morsel['expires'], time_template)) - time.timezone
expires = calendar.timegm(
time.strptime(morsel['expires'], time_template)
)
return create_cookie(
comment=morsel['comment'],
comment_url=bool(morsel['comment']),

View File

@ -97,3 +97,18 @@ class StreamConsumedError(RequestException, TypeError):
class RetryError(RequestException):
"""Custom retries logic failed"""
# Warnings
class RequestsWarning(Warning):
"""Base warning for Requests."""
pass
class FileModeWarning(RequestsWarning, DeprecationWarning):
"""
A file was opened in text mode, but Requests determined its binary length.
"""
pass

View File

@ -12,34 +12,23 @@ Available hooks:
The response generated from a Request.
"""
HOOKS = ['response']
def default_hooks():
hooks = {}
for event in HOOKS:
hooks[event] = []
return hooks
return dict((event, []) for event in HOOKS)
# TODO: response is the only one
def dispatch_hook(key, hooks, hook_data, **kwargs):
"""Dispatches a hook dictionary on a given piece of data."""
hooks = hooks or dict()
if key in hooks:
hooks = hooks.get(key)
hooks = hooks.get(key)
if hooks:
if hasattr(hooks, '__call__'):
hooks = [hooks]
for hook in hooks:
_hook_data = hook(hook_data, **kwargs)
if _hook_data is not None:
hook_data = _hook_data
return hook_data

View File

@ -15,7 +15,7 @@ from .hooks import default_hooks
from .structures import CaseInsensitiveDict
from .auth import HTTPBasicAuth
from .cookies import cookiejar_from_dict, get_cookie_header
from .cookies import cookiejar_from_dict, get_cookie_header, _copy_cookie_jar
from .packages.urllib3.fields import RequestField
from .packages.urllib3.filepost import encode_multipart_formdata
from .packages.urllib3.util import parse_url
@ -30,7 +30,8 @@ from .utils import (
iter_slices, guess_json_utf, super_len, to_native_string)
from .compat import (
cookielib, urlunparse, urlsplit, urlencode, str, bytes, StringIO,
is_py2, chardet, json, builtin_str, basestring)
is_py2, chardet, builtin_str, basestring)
from .compat import json as complexjson
from .status_codes import codes
#: The set of HTTP status codes that indicate an automatically
@ -42,12 +43,11 @@ REDIRECT_STATI = (
codes.temporary_redirect, # 307
codes.permanent_redirect, # 308
)
DEFAULT_REDIRECT_LIMIT = 30
CONTENT_CHUNK_SIZE = 10 * 1024
ITER_CHUNK_SIZE = 512
json_dumps = json.dumps
class RequestEncodingMixin(object):
@property
@ -149,8 +149,7 @@ class RequestEncodingMixin(object):
else:
fdata = fp.read()
rf = RequestField(name=k, data=fdata,
filename=fn, headers=fh)
rf = RequestField(name=k, data=fdata, filename=fn, headers=fh)
rf.make_multipart(content_type=ft)
new_fields.append(rf)
@ -193,7 +192,7 @@ class Request(RequestHooksMixin):
:param headers: dictionary of headers to send.
:param files: dictionary of {filename: fileobject} files to multipart upload.
:param data: the body to attach to the request. If a dictionary is provided, form-encoding will take place.
:param json: json for the body to attach to the request (if data is not specified).
:param json: json for the body to attach to the request (if files or data is not specified).
:param params: dictionary of URL parameters to append to the URL.
:param auth: Auth handler or (user, pass) tuple.
:param cookies: dictionary or CookieJar of cookies to attach to this request.
@ -207,17 +206,8 @@ class Request(RequestHooksMixin):
<PreparedRequest [GET]>
"""
def __init__(self,
method=None,
url=None,
headers=None,
files=None,
data=None,
params=None,
auth=None,
cookies=None,
hooks=None,
json=None):
def __init__(self, method=None, url=None, headers=None, files=None,
data=None, params=None, auth=None, cookies=None, hooks=None, json=None):
# Default empty dicts for dict params.
data = [] if data is None else data
@ -296,8 +286,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self.hooks = default_hooks()
def prepare(self, method=None, url=None, headers=None, files=None,
data=None, params=None, auth=None, cookies=None, hooks=None,
json=None):
data=None, params=None, auth=None, cookies=None, hooks=None, json=None):
"""Prepares the entire request with the given parameters."""
self.prepare_method(method)
@ -306,6 +295,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self.prepare_cookies(cookies)
self.prepare_body(data, files, json)
self.prepare_auth(auth, url)
# Note that prepare_auth must be last to enable authentication schemes
# such as OAuth to work on a fully prepared request.
@ -320,7 +310,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
p.method = self.method
p.url = self.url
p.headers = self.headers.copy() if self.headers is not None else None
p._cookies = self._cookies.copy() if self._cookies is not None else None
p._cookies = _copy_cookie_jar(self._cookies)
p.body = self.body
p.hooks = self.hooks
return p
@ -329,12 +319,12 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
"""Prepares the given HTTP method."""
self.method = method
if self.method is not None:
self.method = self.method.upper()
self.method = to_native_string(self.method.upper())
def prepare_url(self, url, params):
"""Prepares the given HTTP URL."""
#: Accept objects that have string representations.
#: We're unable to blindy call unicode/str functions
#: We're unable to blindly call unicode/str functions
#: as this will include the bytestring indicator (b'')
#: on python 3.x.
#: https://github.com/kennethreitz/requests/pull/2238
@ -357,8 +347,10 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
raise InvalidURL(*e.args)
if not scheme:
raise MissingSchema("Invalid URL {0!r}: No schema supplied. "
"Perhaps you meant http://{0}?".format(url))
error = ("Invalid URL {0!r}: No schema supplied. Perhaps you meant http://{0}?")
error = error.format(to_native_string(url, 'utf8'))
raise MissingSchema(error)
if not host:
raise InvalidURL("Invalid URL %r: No host supplied" % url)
@ -393,6 +385,9 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
if isinstance(fragment, str):
fragment = fragment.encode('utf-8')
if isinstance(params, (str, bytes)):
params = to_native_string(params)
enc_params = self._encode_params(params)
if enc_params:
if query:
@ -422,9 +417,9 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
content_type = None
length = None
if json is not None:
if not data and json is not None:
content_type = 'application/json'
body = json_dumps(json)
body = complexjson.dumps(json)
is_stream = all([
hasattr(data, '__iter__'),
@ -442,7 +437,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
if files:
raise NotImplementedError('Streamed bodies and files are mutually exclusive.')
if length is not None:
if length:
self.headers['Content-Length'] = builtin_str(length)
else:
self.headers['Transfer-Encoding'] = 'chunked'
@ -451,7 +446,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
if files:
(body, content_type) = self._encode_files(files, data)
else:
if data and json is None:
if data:
body = self._encode_params(data)
if isinstance(data, basestring) or hasattr(data, 'read'):
content_type = None
@ -501,7 +496,15 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self.prepare_content_length(self.body)
def prepare_cookies(self, cookies):
"""Prepares the given HTTP cookie data."""
"""Prepares the given HTTP cookie data.
This function eventually generates a ``Cookie`` header from the
given cookies using cookielib. Due to cookielib's design, the header
will not be regenerated if it already exists, meaning this function
can only be called once for the life of the
:class:`PreparedRequest <PreparedRequest>` object. Any subsequent calls
to ``prepare_cookies`` will have no actual effect, unless the "Cookie"
header is removed beforehand."""
if isinstance(cookies, cookielib.CookieJar):
self._cookies = cookies
@ -514,6 +517,10 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
def prepare_hooks(self, hooks):
"""Prepares the given hooks."""
# hooks can be passed as None to the prepare method and to this
# method. To prevent iterating over None, simply use an empty list
# if hooks is False-y
hooks = hooks or []
for event in hooks:
self.register_hook(event, hooks[event])
@ -524,16 +531,8 @@ class Response(object):
"""
__attrs__ = [
'_content',
'status_code',
'headers',
'url',
'history',
'encoding',
'reason',
'cookies',
'elapsed',
'request',
'_content', 'status_code', 'headers', 'url', 'history',
'encoding', 'reason', 'cookies', 'elapsed', 'request'
]
def __init__(self):
@ -635,7 +634,7 @@ class Response(object):
@property
def is_permanent_redirect(self):
"""True if this Response one of the permanant versions of redirect"""
"""True if this Response one of the permanent versions of redirect"""
return ('location' in self.headers and self.status_code in (codes.moved_permanently, codes.permanent_redirect))
@property
@ -653,9 +652,10 @@ class Response(object):
If decode_unicode is True, content will be decoded using the best
available encoding based on the response.
"""
def generate():
try:
# Special case for urllib3.
# Special case for urllib3.
if hasattr(self.raw, 'stream'):
try:
for chunk in self.raw.stream(chunk_size, decode_content=True):
yield chunk
@ -665,7 +665,7 @@ class Response(object):
raise ContentDecodingError(e)
except ReadTimeoutError as e:
raise ConnectionError(e)
except AttributeError:
else:
# Standard file-like object.
while True:
chunk = self.raw.read(chunk_size)
@ -796,14 +796,16 @@ class Response(object):
encoding = guess_json_utf(self.content)
if encoding is not None:
try:
return json.loads(self.content.decode(encoding), **kwargs)
return complexjson.loads(
self.content.decode(encoding), **kwargs
)
except UnicodeDecodeError:
# Wrong UTF codec detected; usually because it's not UTF-8
# but some other 8-bit codec. This is an RFC violation,
# and the server didn't bother to tell us what codec *was*
# used.
pass
return json.loads(self.text, **kwargs)
return complexjson.loads(self.text, **kwargs)
@property
def links(self):
@ -829,10 +831,10 @@ class Response(object):
http_error_msg = ''
if 400 <= self.status_code < 500:
http_error_msg = '%s Client Error: %s' % (self.status_code, self.reason)
http_error_msg = '%s Client Error: %s for url: %s' % (self.status_code, self.reason, self.url)
elif 500 <= self.status_code < 600:
http_error_msg = '%s Server Error: %s' % (self.status_code, self.reason)
http_error_msg = '%s Server Error: %s for url: %s' % (self.status_code, self.reason, self.url)
if http_error_msg:
raise HTTPError(http_error_msg, response=self)
@ -843,4 +845,7 @@ class Response(object):
*Note: Should not normally need to be called explicitly.*
"""
if not self._content_consumed:
return self.raw.close()
return self.raw.release_conn()

View File

@ -1,8 +1,11 @@
If you are planning to submit a pull request to requests with any changes in
this library do not go any further. These are independent libraries which we
vendor into requests. Any changes necessary to these libraries must be made in
this library do not go any further. These are independent libraries which we
vendor into requests. Any changes necessary to these libraries must be made in
them and submitted as separate pull requests to those libraries.
urllib3 pull requests go here: https://github.com/shazow/urllib3
chardet pull requests go here: https://github.com/chardet/chardet
See https://github.com/kennethreitz/requests/pull/1812#issuecomment-30854316
for the reasoning behind this.

View File

@ -1,107 +1,36 @@
"""
Copyright (c) Donald Stufft, pip, and individual contributors
'''
Debian and other distributions "unbundle" requests' vendored dependencies, and
rewrite all imports to use the global versions of ``urllib3`` and ``chardet``.
The problem with this is that not only requests itself imports those
dependencies, but third-party code outside of the distros' control too.
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
In reaction to these problems, the distro maintainers replaced
``requests.packages`` with a magical "stub module" that imports the correct
modules. The implementations were varying in quality and all had severe
problems. For example, a symlink (or hardlink) that links the correct modules
into place introduces problems regarding object identity, since you now have
two modules in `sys.modules` with the same API, but different identities::
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
requests.packages.urllib3 is not urllib3
With version ``2.5.2``, requests started to maintain its own stub, so that
distro-specific breakage would be reduced to a minimum, even though the whole
issue is not requests' fault in the first place. See
https://github.com/kennethreitz/requests/pull/2375 for the corresponding pull
request.
'''
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
import sys
try:
from . import urllib3
except ImportError:
import urllib3
sys.modules['%s.urllib3' % __name__] = urllib3
class VendorAlias(object):
def __init__(self, package_names):
self._package_names = package_names
self._vendor_name = __name__
self._vendor_pkg = self._vendor_name + "."
self._vendor_pkgs = [
self._vendor_pkg + name for name in self._package_names
]
def find_module(self, fullname, path=None):
if fullname.startswith(self._vendor_pkg):
return self
def load_module(self, name):
# Ensure that this only works for the vendored name
if not name.startswith(self._vendor_pkg):
raise ImportError(
"Cannot import %s, must be a subpackage of '%s'." % (
name, self._vendor_name,
)
)
if not (name == self._vendor_name or
any(name.startswith(pkg) for pkg in self._vendor_pkgs)):
raise ImportError(
"Cannot import %s, must be one of %s." % (
name, self._vendor_pkgs
)
)
# Check to see if we already have this item in sys.modules, if we do
# then simply return that.
if name in sys.modules:
return sys.modules[name]
# Check to see if we can import the vendor name
try:
# We do this dance here because we want to try and import this
# module without hitting a recursion error because of a bunch of
# VendorAlias instances on sys.meta_path
real_meta_path = sys.meta_path[:]
try:
sys.meta_path = [
m for m in sys.meta_path
if not isinstance(m, VendorAlias)
]
__import__(name)
module = sys.modules[name]
finally:
# Re-add any additions to sys.meta_path that were made while
# during the import we just did, otherwise things like
# requests.packages.urllib3.poolmanager will fail.
for m in sys.meta_path:
if m not in real_meta_path:
real_meta_path.append(m)
# Restore sys.meta_path with any new items.
sys.meta_path = real_meta_path
except ImportError:
# We can't import the vendor name, so we'll try to import the
# "real" name.
real_name = name[len(self._vendor_pkg):]
try:
__import__(real_name)
module = sys.modules[real_name]
except ImportError:
raise ImportError("No module named '%s'" % (name,))
# If we've gotten here we've found the module we're looking for, either
# as part of our vendored package, or as the real name, so we'll add
# it to sys.modules as the vendored name so that we don't have to do
# the lookup again.
sys.modules[name] = module
# Finally, return the loaded module
return module
sys.meta_path.append(VendorAlias(["urllib3", "chardet"]))
try:
from . import chardet
except ImportError:
import chardet
sys.modules['%s.chardet' % __name__] = chardet

View File

@ -2,10 +2,8 @@
urllib3 - Thread-safe connection pooling and re-using.
"""
__author__ = 'Andrey Petrov (andrey.petrov@shazow.net)'
__license__ = 'MIT'
__version__ = '1.10.2'
from __future__ import absolute_import
import warnings
from .connectionpool import (
HTTPConnectionPool,
@ -32,8 +30,30 @@ except ImportError:
def emit(self, record):
pass
__author__ = 'Andrey Petrov (andrey.petrov@shazow.net)'
__license__ = 'MIT'
__version__ = '1.13.1'
__all__ = (
'HTTPConnectionPool',
'HTTPSConnectionPool',
'PoolManager',
'ProxyManager',
'HTTPResponse',
'Retry',
'Timeout',
'add_stderr_logger',
'connection_from_url',
'disable_warnings',
'encode_multipart_formdata',
'get_host',
'make_headers',
'proxy_from_url',
)
logging.getLogger(__name__).addHandler(NullHandler())
def add_stderr_logger(level=logging.DEBUG):
"""
Helper for quickly adding a StreamHandler to the logger. Useful for
@ -55,9 +75,16 @@ def add_stderr_logger(level=logging.DEBUG):
del NullHandler
# Set security warning to always go off by default.
import warnings
warnings.simplefilter('always', exceptions.SecurityWarning)
# SecurityWarning's always go off by default.
warnings.simplefilter('always', exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host
warnings.simplefilter('default', exceptions.SubjectAltNameWarning)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter('default', exceptions.InsecurePlatformWarning,
append=True)
# SNIMissingWarnings should go off only once.
warnings.simplefilter('default', exceptions.SNIMissingWarning)
def disable_warnings(category=exceptions.HTTPWarning):
"""

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
from collections import Mapping, MutableMapping
try:
from threading import RLock
@ -97,14 +98,7 @@ class RecentlyUsedContainer(MutableMapping):
return list(iterkeys(self._container))
_dict_setitem = dict.__setitem__
_dict_getitem = dict.__getitem__
_dict_delitem = dict.__delitem__
_dict_contains = dict.__contains__
_dict_setdefault = dict.setdefault
class HTTPHeaderDict(dict):
class HTTPHeaderDict(MutableMapping):
"""
:param headers:
An iterable of field-value pairs. Must not contain multiple field names
@ -139,7 +133,8 @@ class HTTPHeaderDict(dict):
"""
def __init__(self, headers=None, **kwargs):
dict.__init__(self)
super(HTTPHeaderDict, self).__init__()
self._container = {}
if headers is not None:
if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers)
@ -149,38 +144,44 @@ class HTTPHeaderDict(dict):
self.extend(kwargs)
def __setitem__(self, key, val):
return _dict_setitem(self, key.lower(), (key, val))
self._container[key.lower()] = (key, val)
return self._container[key.lower()]
def __getitem__(self, key):
val = _dict_getitem(self, key.lower())
val = self._container[key.lower()]
return ', '.join(val[1:])
def __delitem__(self, key):
return _dict_delitem(self, key.lower())
del self._container[key.lower()]
def __contains__(self, key):
return _dict_contains(self, key.lower())
return key.lower() in self._container
def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return dict((k1, self[k1]) for k1 in self) == dict((k2, other[k2]) for k2 in other)
return (dict((k.lower(), v) for k, v in self.itermerged()) ==
dict((k.lower(), v) for k, v in other.itermerged()))
def __ne__(self, other):
return not self.__eq__(other)
values = MutableMapping.values
get = MutableMapping.get
update = MutableMapping.update
if not PY3: # Python 2
if not PY3: # Python 2
iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues
__marker = object()
def __len__(self):
return len(self._container)
def __iter__(self):
# Only provide the originally cased names
for vals in self._container.values():
yield vals[0]
def pop(self, key, default=__marker):
'''D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
@ -216,7 +217,7 @@ class HTTPHeaderDict(dict):
key_lower = key.lower()
new_vals = key, val
# Keep the common case aka no item present as fast as possible
vals = _dict_setdefault(self, key_lower, new_vals)
vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals:
# new_vals was not inserted, as there was a previous one
if isinstance(vals, list):
@ -225,22 +226,22 @@ class HTTPHeaderDict(dict):
else:
# vals should be a tuple then, i.e. only one item so far
# Need to convert the tuple to list for further extension
_dict_setitem(self, key_lower, [vals[0], vals[1], val])
self._container[key_lower] = [vals[0], vals[1], val]
def extend(*args, **kwargs):
def extend(self, *args, **kwargs):
"""Generic import function for any type of header-like object.
Adapted version of MutableMapping.update in order to insert items
with self.add instead of self.__setitem__
"""
if len(args) > 2:
raise TypeError("update() takes at most 2 positional "
"arguments ({} given)".format(len(args)))
elif not args:
raise TypeError("update() takes at least 1 argument (0 given)")
self = args[0]
other = args[1] if len(args) >= 2 else ()
if isinstance(other, Mapping):
if len(args) > 1:
raise TypeError("extend() takes at most 1 positional "
"arguments ({0} given)".format(len(args)))
other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems():
self.add(key, val)
elif isinstance(other, Mapping):
for key in other:
self.add(key, other[key])
elif hasattr(other, "keys"):
@ -257,7 +258,7 @@ class HTTPHeaderDict(dict):
"""Returns a list of all the values for the named field. Returns an
empty list if the key doesn't exist."""
try:
vals = _dict_getitem(self, key.lower())
vals = self._container[key.lower()]
except KeyError:
return []
else:
@ -276,11 +277,11 @@ class HTTPHeaderDict(dict):
def _copy_from(self, other):
for key in other:
val = _dict_getitem(other, key)
val = other.getlist(key)
if isinstance(val, list):
# Don't need to convert tuples
val = list(val)
_dict_setitem(self, key, val)
self._container[key.lower()] = [key] + val
def copy(self):
clone = type(self)()
@ -290,31 +291,34 @@ class HTTPHeaderDict(dict):
def iteritems(self):
"""Iterate over all header lines, including duplicate ones."""
for key in self:
vals = _dict_getitem(self, key)
vals = self._container[key.lower()]
for val in vals[1:]:
yield vals[0], val
def itermerged(self):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = _dict_getitem(self, key)
val = self._container[key.lower()]
yield val[0], ', '.join(val[1:])
def items(self):
return list(self.iteritems())
@classmethod
def from_httplib(cls, message, duplicates=('set-cookie',)): # Python 2
def from_httplib(cls, message): # Python 2
"""Read headers from a Python 2 httplib message object."""
ret = cls(message.items())
# ret now contains only the last header line for each duplicate.
# Importing with all duplicates would be nice, but this would
# mean to repeat most of the raw parsing already done, when the
# message object was created. Extracting only the headers of interest
# separately, the cookies, should be faster and requires less
# extra code.
for key in duplicates:
ret.discard(key)
for val in message.getheaders(key):
ret.add(key, val)
return ret
# python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
headers = []
for line in message.headers:
if line.startswith((' ', '\t')):
key, value = headers[-1]
headers[-1] = (key, value + '\r\n' + line.rstrip())
continue
key, value = line.split(':', 1)
headers.append((key, value.strip()))
return cls(headers)

View File

@ -1,23 +1,20 @@
from __future__ import absolute_import
import datetime
import os
import sys
import socket
from socket import timeout as SocketTimeout
from socket import error as SocketError, timeout as SocketTimeout
import warnings
from .packages import six
try: # Python 3
from http.client import HTTPConnection as _HTTPConnection, HTTPException
from http.client import HTTPConnection as _HTTPConnection
from http.client import HTTPException # noqa: unused in this module
except ImportError:
from httplib import HTTPConnection as _HTTPConnection, HTTPException
class DummyConnection(object):
"Used to detect a failed ConnectionCls import."
pass
from httplib import HTTPConnection as _HTTPConnection
from httplib import HTTPException # noqa: unused in this module
try: # Compiled with SSL?
HTTPSConnection = DummyConnection
import ssl
BaseSSLError = ssl.SSLError
except (ImportError, AttributeError): # Platform-specific: No SSL.
@ -36,9 +33,10 @@ except NameError: # Python 2:
from .exceptions import (
NewConnectionError,
ConnectTimeoutError,
SubjectAltNameWarning,
SystemTimeWarning,
SecurityWarning,
)
from .packages.ssl_match_hostname import match_hostname
@ -60,6 +58,11 @@ port_by_scheme = {
RECENT_DATE = datetime.date(2014, 1, 1)
class DummyConnection(object):
"""Used to detect a failed ConnectionCls import."""
pass
class HTTPConnection(_HTTPConnection, object):
"""
Based on httplib.HTTPConnection but provides an extra constructor
@ -133,11 +136,15 @@ class HTTPConnection(_HTTPConnection, object):
conn = connection.create_connection(
(self.host, self.port), self.timeout, **extra_kw)
except SocketTimeout:
except SocketTimeout as e:
raise ConnectTimeoutError(
self, "Connection to %s timed out. (connect timeout=%s)" %
(self.host, self.timeout))
except SocketError as e:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e)
return conn
def _prepare_conn(self, conn):
@ -185,19 +192,25 @@ class VerifiedHTTPSConnection(HTTPSConnection):
"""
cert_reqs = None
ca_certs = None
ca_cert_dir = None
ssl_version = None
assert_fingerprint = None
def set_cert(self, key_file=None, cert_file=None,
cert_reqs=None, ca_certs=None,
assert_hostname=None, assert_fingerprint=None):
assert_hostname=None, assert_fingerprint=None,
ca_cert_dir=None):
if (ca_certs or ca_cert_dir) and cert_reqs is None:
cert_reqs = 'CERT_REQUIRED'
self.key_file = key_file
self.cert_file = cert_file
self.cert_reqs = cert_reqs
self.ca_certs = ca_certs
self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint
self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
def connect(self):
# Add certificate verification
@ -234,6 +247,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
self.sock = ssl_wrap_socket(conn, self.key_file, self.cert_file,
cert_reqs=resolved_cert_reqs,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
server_hostname=hostname,
ssl_version=resolved_ssl_version)
@ -245,18 +259,30 @@ class VerifiedHTTPSConnection(HTTPSConnection):
cert = self.sock.getpeercert()
if not cert.get('subjectAltName', ()):
warnings.warn((
'Certificate has no `subjectAltName`, falling back to check for a `commonName` for now. '
'This feature is being removed by major browsers and deprecated by RFC 2818. '
'(See https://github.com/shazow/urllib3/issues/497 for details.)'),
SecurityWarning
'Certificate for {0} has no `subjectAltName`, falling back to check for a '
'`commonName` for now. This feature is being removed by major browsers and '
'deprecated by RFC 2818. (See https://github.com/shazow/urllib3/issues/497 '
'for details.)'.format(hostname)),
SubjectAltNameWarning
)
match_hostname(cert, self.assert_hostname or hostname)
self.is_verified = (resolved_cert_reqs == ssl.CERT_REQUIRED
or self.assert_fingerprint is not None)
# In case the hostname is an IPv6 address, strip the square
# brackets from it before using it to validate. This is because
# a certificate with an IPv6 address in it won't have square
# brackets around that address. Sadly, match_hostname won't do this
# for us: it expects the plain host part without any extra work
# that might have been done to make it palatable to httplib.
asserted_hostname = self.assert_hostname or hostname
asserted_hostname = asserted_hostname.strip('[]')
match_hostname(cert, asserted_hostname)
self.is_verified = (resolved_cert_reqs == ssl.CERT_REQUIRED or
self.assert_fingerprint is not None)
if ssl:
# Make a copy for testing.
UnverifiedHTTPSConnection = HTTPSConnection
HTTPSConnection = VerifiedHTTPSConnection
else:
HTTPSConnection = DummyConnection

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
import errno
import logging
import sys
@ -10,13 +11,15 @@ try: # Python 3
from queue import LifoQueue, Empty, Full
except ImportError:
from Queue import LifoQueue, Empty, Full
import Queue as _ # Platform-specific: Windows
# Queue is imported for side effects on MS Windows
import Queue as _unused_module_Queue # noqa: unused
from .exceptions import (
ClosedPoolError,
ProtocolError,
EmptyPoolError,
HeaderParsingError,
HostChangedError,
LocationValueError,
MaxRetryError,
@ -25,6 +28,7 @@ from .exceptions import (
SSLError,
TimeoutError,
InsecureRequestWarning,
NewConnectionError,
)
from .packages.ssl_match_hostname import CertificateError
from .packages import six
@ -32,15 +36,16 @@ from .connection import (
port_by_scheme,
DummyConnection,
HTTPConnection, HTTPSConnection, VerifiedHTTPSConnection,
HTTPException, BaseSSLError, ConnectionError
HTTPException, BaseSSLError,
)
from .request import RequestMethods
from .response import HTTPResponse
from .util.connection import is_connection_dropped
from .util.response import assert_header_parsing
from .util.retry import Retry
from .util.timeout import Timeout
from .util.url import get_host
from .util.url import get_host, Url
xrange = six.moves.xrange
@ -50,7 +55,7 @@ log = logging.getLogger(__name__)
_Default = object()
## Pool objects
# Pool objects
class ConnectionPool(object):
"""
Base class for all connection pools, such as
@ -64,8 +69,7 @@ class ConnectionPool(object):
if not host:
raise LocationValueError("No host specified.")
# httplib doesn't like it when we include brackets in ipv6 addresses
self.host = host.strip('[]')
self.host = host
self.port = port
def __str__(self):
@ -120,7 +124,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
:param maxsize:
Number of connections to save that can be reused. More than 1 is useful
in multithreaded situations. If ``block`` is set to false, more
in multithreaded situations. If ``block`` is set to False, more
connections will be created but they will not be saved once they've
been used.
@ -381,8 +385,19 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
log.debug("\"%s %s %s\" %s %s" % (method, url, http_version,
httplib_response.status,
httplib_response.length))
try:
assert_header_parsing(httplib_response.msg)
except HeaderParsingError as hpe: # Platform-specific: Python 3
log.warning(
'Failed to parse headers (url=%s): %s',
self._absolute_url(url), hpe, exc_info=True)
return httplib_response
def _absolute_url(self, path):
return Url(scheme=self.scheme, host=self.host, port=self.port, path=path).url
def close(self):
"""
Close all pooled connections and disable the pool.
@ -568,27 +583,24 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Close the connection. If a connection is reused on which there
# was a Certificate error, the next request will certainly raise
# another Certificate error.
if conn:
conn.close()
conn = None
conn = conn and conn.close()
release_conn = True
raise SSLError(e)
except SSLError:
# Treat SSLError separately from BaseSSLError to preserve
# traceback.
if conn:
conn.close()
conn = None
conn = conn and conn.close()
release_conn = True
raise
except (TimeoutError, HTTPException, SocketError, ConnectionError) as e:
if conn:
# Discard the connection for these exceptions. It will be
# be replaced during the next _get_conn() call.
conn.close()
conn = None
except (TimeoutError, HTTPException, SocketError, ProtocolError) as e:
# Discard the connection for these exceptions. It will be
# be replaced during the next _get_conn() call.
conn = conn and conn.close()
release_conn = True
if isinstance(e, SocketError) and self.proxy:
if isinstance(e, (SocketError, NewConnectionError)) and self.proxy:
e = ProxyError('Cannot connect to proxy.', e)
elif isinstance(e, (SocketError, HTTPException)):
e = ProtocolError('Connection aborted.', e)
@ -626,26 +638,31 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
retries = retries.increment(method, url, response=response, _pool=self)
except MaxRetryError:
if retries.raise_on_redirect:
# Release the connection for this response, since we're not
# returning it to be released manually.
response.release_conn()
raise
return response
log.info("Redirecting %s -> %s" % (url, redirect_location))
return self.urlopen(method, redirect_location, body, headers,
retries=retries, redirect=redirect,
assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout,
release_conn=release_conn, **response_kw)
return self.urlopen(
method, redirect_location, body, headers,
retries=retries, redirect=redirect,
assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout,
release_conn=release_conn, **response_kw)
# Check if we should retry the HTTP response.
if retries.is_forced_retry(method, status_code=response.status):
retries = retries.increment(method, url, response=response, _pool=self)
retries.sleep()
log.info("Forced retry: %s" % url)
return self.urlopen(method, url, body, headers,
retries=retries, redirect=redirect,
assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout,
release_conn=release_conn, **response_kw)
return self.urlopen(
method, url, body, headers,
retries=retries, redirect=redirect,
assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout,
release_conn=release_conn, **response_kw)
return response
@ -662,10 +679,10 @@ class HTTPSConnectionPool(HTTPConnectionPool):
``assert_hostname`` and ``host`` in this order to verify connections.
If ``assert_hostname`` is False, no verification is done.
The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs`` and
``ssl_version`` are only used if :mod:`ssl` is available and are fed into
:meth:`urllib3.util.ssl_wrap_socket` to upgrade the connection socket
into an SSL socket.
The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``,
``ca_cert_dir``, and ``ssl_version`` are only used if :mod:`ssl` is
available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade
the connection socket into an SSL socket.
"""
scheme = 'https'
@ -678,15 +695,20 @@ class HTTPSConnectionPool(HTTPConnectionPool):
key_file=None, cert_file=None, cert_reqs=None,
ca_certs=None, ssl_version=None,
assert_hostname=None, assert_fingerprint=None,
**conn_kw):
ca_cert_dir=None, **conn_kw):
HTTPConnectionPool.__init__(self, host, port, strict, timeout, maxsize,
block, headers, retries, _proxy, _proxy_headers,
**conn_kw)
if ca_certs and cert_reqs is None:
cert_reqs = 'CERT_REQUIRED'
self.key_file = key_file
self.cert_file = cert_file
self.cert_reqs = cert_reqs
self.ca_certs = ca_certs
self.ca_cert_dir = ca_cert_dir
self.ssl_version = ssl_version
self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint
@ -702,6 +724,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
cert_file=self.cert_file,
cert_reqs=self.cert_reqs,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
assert_hostname=self.assert_hostname,
assert_fingerprint=self.assert_fingerprint)
conn.ssl_version = self.ssl_version
@ -735,7 +758,6 @@ class HTTPSConnectionPool(HTTPConnectionPool):
% (self.num_connections, self.host))
if not self.ConnectionCls or self.ConnectionCls is DummyConnection:
# Platform-specific: Python without ssl
raise SSLError("Can't connect to HTTPS URL because the SSL "
"module is not available.")

View File

@ -0,0 +1,223 @@
from __future__ import absolute_import
import logging
import os
import warnings
from ..exceptions import (
HTTPError,
HTTPWarning,
MaxRetryError,
ProtocolError,
TimeoutError,
SSLError
)
from ..packages.six import BytesIO
from ..request import RequestMethods
from ..response import HTTPResponse
from ..util.timeout import Timeout
from ..util.retry import Retry
try:
from google.appengine.api import urlfetch
except ImportError:
urlfetch = None
log = logging.getLogger(__name__)
class AppEnginePlatformWarning(HTTPWarning):
pass
class AppEnginePlatformError(HTTPError):
pass
class AppEngineManager(RequestMethods):
"""
Connection manager for Google App Engine sandbox applications.
This manager uses the URLFetch service directly instead of using the
emulated httplib, and is subject to URLFetch limitations as described in
the App Engine documentation here:
https://cloud.google.com/appengine/docs/python/urlfetch
Notably it will raise an AppEnginePlatformError if:
* URLFetch is not available.
* If you attempt to use this on GAEv2 (Managed VMs), as full socket
support is available.
* If a request size is more than 10 megabytes.
* If a response size is more than 32 megabtyes.
* If you use an unsupported request method such as OPTIONS.
Beyond those cases, it will raise normal urllib3 errors.
"""
def __init__(self, headers=None, retries=None, validate_certificate=True):
if not urlfetch:
raise AppEnginePlatformError(
"URLFetch is not available in this environment.")
if is_prod_appengine_mvms():
raise AppEnginePlatformError(
"Use normal urllib3.PoolManager instead of AppEngineManager"
"on Managed VMs, as using URLFetch is not necessary in "
"this environment.")
warnings.warn(
"urllib3 is using URLFetch on Google App Engine sandbox instead "
"of sockets. To use sockets directly instead of URLFetch see "
"https://urllib3.readthedocs.org/en/latest/contrib.html.",
AppEnginePlatformWarning)
RequestMethods.__init__(self, headers)
self.validate_certificate = validate_certificate
self.retries = retries or Retry.DEFAULT
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Return False to re-raise any potential exceptions
return False
def urlopen(self, method, url, body=None, headers=None,
retries=None, redirect=True, timeout=Timeout.DEFAULT_TIMEOUT,
**response_kw):
retries = self._get_retries(retries, redirect)
try:
response = urlfetch.fetch(
url,
payload=body,
method=method,
headers=headers or {},
allow_truncated=False,
follow_redirects=(
redirect and
retries.redirect != 0 and
retries.total),
deadline=self._get_absolute_timeout(timeout),
validate_certificate=self.validate_certificate,
)
except urlfetch.DeadlineExceededError as e:
raise TimeoutError(self, e)
except urlfetch.InvalidURLError as e:
if 'too large' in str(e):
raise AppEnginePlatformError(
"URLFetch request too large, URLFetch only "
"supports requests up to 10mb in size.", e)
raise ProtocolError(e)
except urlfetch.DownloadError as e:
if 'Too many redirects' in str(e):
raise MaxRetryError(self, url, reason=e)
raise ProtocolError(e)
except urlfetch.ResponseTooLargeError as e:
raise AppEnginePlatformError(
"URLFetch response too large, URLFetch only supports"
"responses up to 32mb in size.", e)
except urlfetch.SSLCertificateError as e:
raise SSLError(e)
except urlfetch.InvalidMethodError as e:
raise AppEnginePlatformError(
"URLFetch does not support method: %s" % method, e)
http_response = self._urlfetch_response_to_http_response(
response, **response_kw)
# Check for redirect response
if (http_response.get_redirect_location() and
retries.raise_on_redirect and redirect):
raise MaxRetryError(self, url, "too many redirects")
# Check if we should retry the HTTP response.
if retries.is_forced_retry(method, status_code=http_response.status):
retries = retries.increment(
method, url, response=http_response, _pool=self)
log.info("Forced retry: %s" % url)
retries.sleep()
return self.urlopen(
method, url,
body=body, headers=headers,
retries=retries, redirect=redirect,
timeout=timeout, **response_kw)
return http_response
def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw):
if is_prod_appengine():
# Production GAE handles deflate encoding automatically, but does
# not remove the encoding header.
content_encoding = urlfetch_resp.headers.get('content-encoding')
if content_encoding == 'deflate':
del urlfetch_resp.headers['content-encoding']
return HTTPResponse(
# In order for decoding to work, we must present the content as
# a file-like object.
body=BytesIO(urlfetch_resp.content),
headers=urlfetch_resp.headers,
status=urlfetch_resp.status_code,
**response_kw
)
def _get_absolute_timeout(self, timeout):
if timeout is Timeout.DEFAULT_TIMEOUT:
return 5 # 5s is the default timeout for URLFetch.
if isinstance(timeout, Timeout):
if timeout.read is not timeout.connect:
warnings.warn(
"URLFetch does not support granular timeout settings, "
"reverting to total timeout.", AppEnginePlatformWarning)
return timeout.total
return timeout
def _get_retries(self, retries, redirect):
if not isinstance(retries, Retry):
retries = Retry.from_int(
retries, redirect=redirect, default=self.retries)
if retries.connect or retries.read or retries.redirect:
warnings.warn(
"URLFetch only supports total retries and does not "
"recognize connect, read, or redirect retry parameters.",
AppEnginePlatformWarning)
return retries
def is_appengine():
return (is_local_appengine() or
is_prod_appengine() or
is_prod_appengine_mvms())
def is_appengine_sandbox():
return is_appengine() and not is_prod_appengine_mvms()
def is_local_appengine():
return ('APPENGINE_RUNTIME' in os.environ and
'Development/' in os.environ['SERVER_SOFTWARE'])
def is_prod_appengine():
return ('APPENGINE_RUNTIME' in os.environ and
'Google App Engine/' in os.environ['SERVER_SOFTWARE'] and
not is_prod_appengine_mvms())
def is_prod_appengine_mvms():
return os.environ.get('GAE_VM', False) == 'true'

View File

@ -3,6 +3,7 @@ NTLM authenticating pool, contributed by erikcederstran
Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10
"""
from __future__ import absolute_import
try:
from http.client import HTTPSConnection

View File

@ -38,13 +38,12 @@ Module Variables
----------------
:var DEFAULT_SSL_CIPHER_LIST: The list of supported SSL/TLS cipher suites.
Default: ``ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:
ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+3DES:!aNULL:!MD5:!DSS``
.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
'''
from __future__ import absolute_import
try:
from ndg.httpsclient.ssl_peer_verification import SUBJ_ALT_NAME_SUPPORT
@ -55,7 +54,7 @@ except SyntaxError as e:
import OpenSSL.SSL
from pyasn1.codec.der import decoder as der_decoder
from pyasn1.type import univ, constraint
from socket import _fileobject, timeout
from socket import _fileobject, timeout, error as SocketError
import ssl
import select
@ -73,6 +72,12 @@ _openssl_versions = {
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}
if hasattr(ssl, 'PROTOCOL_TLSv1_1') and hasattr(OpenSSL.SSL, 'TLSv1_1_METHOD'):
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
if hasattr(ssl, 'PROTOCOL_TLSv1_2') and hasattr(OpenSSL.SSL, 'TLSv1_2_METHOD'):
_openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD
try:
_openssl_versions.update({ssl.PROTOCOL_SSLv3: OpenSSL.SSL.SSLv3_METHOD})
except AttributeError:
@ -81,27 +86,14 @@ except AttributeError:
_openssl_verify = {
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
ssl.CERT_REQUIRED:
OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
# A secure default.
# Sources for more information on TLS ciphers:
#
# - https://wiki.mozilla.org/Security/Server_Side_TLS
# - https://www.ssllabs.com/projects/best-practices/index.html
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
#
# The general intent is:
# - Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
# - prefer ECDHE over DHE for better performance,
# - prefer any AES-GCM over any AES-CBC for better performance and security,
# - use 3DES as fallback which is secure but slow,
# - disable NULL authentication, MD5 MACs and DSS for security reasons.
DEFAULT_SSL_CIPHER_LIST = "ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:" + \
"ECDH+AES128:DH+AES:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+3DES:" + \
"!aNULL:!MD5:!DSS"
DEFAULT_SSL_CIPHER_LIST = util.ssl_.DEFAULT_CIPHERS
# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384
orig_util_HAS_SNI = util.HAS_SNI
orig_connection_ssl_wrap_socket = connection.ssl_wrap_socket
@ -121,7 +113,7 @@ def extract_from_urllib3():
util.HAS_SNI = orig_util_HAS_SNI
### Note: This is a slightly bug-fixed version of same from ndg-httpsclient.
# Note: This is a slightly bug-fixed version of same from ndg-httpsclient.
class SubjectAltName(BaseSubjectAltName):
'''ASN.1 implementation for subjectAltNames support'''
@ -132,7 +124,7 @@ class SubjectAltName(BaseSubjectAltName):
constraint.ValueSizeConstraint(1, 1024)
### Note: This is a slightly bug-fixed version of same from ndg-httpsclient.
# Note: This is a slightly bug-fixed version of same from ndg-httpsclient.
def get_subj_alt_name(peer_cert):
# Search through extensions
dns_name = []
@ -190,7 +182,7 @@ class WrappedSocket(object):
if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
return b''
else:
raise
raise SocketError(e)
except OpenSSL.SSL.ZeroReturnError as e:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return b''
@ -221,13 +213,21 @@ class WrappedSocket(object):
continue
def sendall(self, data):
while len(data):
sent = self._send_until_done(data)
data = data[sent:]
total_sent = 0
while total_sent < len(data):
sent = self._send_until_done(data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE])
total_sent += sent
def shutdown(self):
# FIXME rethrow compatible exceptions should we ever use this
self.connection.shutdown()
def close(self):
if self._makefile_refs < 1:
return self.connection.shutdown()
try:
return self.connection.close()
except OpenSSL.SSL.Error:
return
else:
self._makefile_refs -= 1
@ -268,7 +268,7 @@ def _verify_callback(cnx, x509, err_no, err_depth, return_code):
def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
ca_certs=None, server_hostname=None,
ssl_version=None):
ssl_version=None, ca_cert_dir=None):
ctx = OpenSSL.SSL.Context(_openssl_versions[ssl_version])
if certfile:
keyfile = keyfile or certfile # Match behaviour of the normal python ssl library
@ -277,9 +277,9 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
ctx.use_privatekey_file(keyfile)
if cert_reqs != ssl.CERT_NONE:
ctx.set_verify(_openssl_verify[cert_reqs], _verify_callback)
if ca_certs:
if ca_certs or ca_cert_dir:
try:
ctx.load_verify_locations(ca_certs, None)
ctx.load_verify_locations(ca_certs, ca_cert_dir)
except OpenSSL.SSL.Error as e:
raise ssl.SSLError('bad ca_certs: %r' % ca_certs, e)
else:
@ -299,10 +299,12 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
try:
cnx.do_handshake()
except OpenSSL.SSL.WantReadError:
select.select([sock], [], [])
rd, _, _ = select.select([sock], [], [], sock.gettimeout())
if not rd:
raise timeout('select timed out')
continue
except OpenSSL.SSL.Error as e:
raise ssl.SSLError('bad handshake', e)
raise ssl.SSLError('bad handshake: %r' % e)
break
return WrappedSocket(cnx, sock)

View File

@ -1,16 +1,17 @@
from __future__ import absolute_import
# Base Exceptions
## Base Exceptions
class HTTPError(Exception):
"Base exception used by this module."
pass
class HTTPWarning(Warning):
"Base warning used by this module."
pass
class PoolError(HTTPError):
"Base exception for errors caused within a pool."
def __init__(self, pool, message):
@ -57,7 +58,7 @@ class ProtocolError(HTTPError):
ConnectionError = ProtocolError
## Leaf Exceptions
# Leaf Exceptions
class MaxRetryError(RequestError):
"""Raised when the maximum number of retries is exceeded.
@ -113,6 +114,11 @@ class ConnectTimeoutError(TimeoutError):
pass
class NewConnectionError(ConnectTimeoutError, PoolError):
"Raised when we fail to establish a new connection. Usually ECONNREFUSED."
pass
class EmptyPoolError(PoolError):
"Raised when a pool runs out of connections and no more are allowed."
pass
@ -149,6 +155,11 @@ class SecurityWarning(HTTPWarning):
pass
class SubjectAltNameWarning(SecurityWarning):
"Warned when connecting to a host with a certificate missing a SAN."
pass
class InsecureRequestWarning(SecurityWarning):
"Warned when making an unverified HTTPS request."
pass
@ -162,3 +173,29 @@ class SystemTimeWarning(SecurityWarning):
class InsecurePlatformWarning(SecurityWarning):
"Warned when certain SSL configuration is not available on a platform."
pass
class SNIMissingWarning(HTTPWarning):
"Warned when making a HTTPS request without SNI available."
pass
class ResponseNotChunked(ProtocolError, ValueError):
"Response needs to be chunked in order to read it as chunks."
pass
class ProxySchemeUnknown(AssertionError, ValueError):
"ProxyManager does not support the supplied scheme"
# TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
def __init__(self, scheme):
message = "Not supported proxy scheme %s" % scheme
super(ProxySchemeUnknown, self).__init__(message)
class HeaderParsingError(HTTPError):
"Raised by assert_header_parsing, but we convert it to a log.warning statement."
def __init__(self, defects, unparsed_data):
message = '%s, unparsed data: %r' % (defects or 'Unknown', unparsed_data)
super(HeaderParsingError, self).__init__(message)

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
import email.utils
import mimetypes

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
import codecs
from uuid import uuid4

View File

@ -2,3 +2,4 @@ from __future__ import absolute_import
from . import ssl_match_hostname
__all__ = ('ssl_match_hostname', )

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
import logging
try: # Python 3
@ -8,7 +9,7 @@ except ImportError:
from ._collections import RecentlyUsedContainer
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from .connectionpool import port_by_scheme
from .exceptions import LocationValueError, MaxRetryError
from .exceptions import LocationValueError, MaxRetryError, ProxySchemeUnknown
from .request import RequestMethods
from .util.url import parse_url
from .util.retry import Retry
@ -25,7 +26,7 @@ pool_classes_by_scheme = {
log = logging.getLogger(__name__)
SSL_KEYWORDS = ('key_file', 'cert_file', 'cert_reqs', 'ca_certs',
'ssl_version')
'ssl_version', 'ca_cert_dir')
class PoolManager(RequestMethods):
@ -227,8 +228,8 @@ class ProxyManager(PoolManager):
port = port_by_scheme.get(proxy.scheme, 80)
proxy = proxy._replace(port=port)
assert proxy.scheme in ("http", "https"), \
'Not supported proxy scheme %s' % proxy.scheme
if proxy.scheme not in ("http", "https"):
raise ProxySchemeUnknown(proxy.scheme)
self.proxy = proxy
self.proxy_headers = proxy_headers or {}

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
try:
from urllib.parse import urlencode
except ImportError:
@ -71,14 +72,22 @@ class RequestMethods(object):
headers=headers,
**urlopen_kw)
def request_encode_url(self, method, url, fields=None, **urlopen_kw):
def request_encode_url(self, method, url, fields=None, headers=None,
**urlopen_kw):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc.
"""
if headers is None:
headers = self.headers
extra_kw = {'headers': headers}
extra_kw.update(urlopen_kw)
if fields:
url += '?' + urlencode(fields)
return self.urlopen(method, url, **urlopen_kw)
return self.urlopen(method, url, **extra_kw)
def request_encode_body(self, method, url, fields=None, headers=None,
encode_multipart=True, multipart_boundary=None,
@ -125,7 +134,8 @@ class RequestMethods(object):
if fields:
if 'body' in urlopen_kw:
raise TypeError('request got values for both \'fields\' and \'body\', can only specify one.')
raise TypeError(
"request got values for both 'fields' and 'body', can only specify one.")
if encode_multipart:
body, content_type = encode_multipart_formdata(fields, boundary=multipart_boundary)

View File

@ -1,12 +1,18 @@
from __future__ import absolute_import
from contextlib import contextmanager
import zlib
import io
from socket import timeout as SocketTimeout
from socket import error as SocketError
from ._collections import HTTPHeaderDict
from .exceptions import ProtocolError, DecodeError, ReadTimeoutError
from .exceptions import (
ProtocolError, DecodeError, ReadTimeoutError, ResponseNotChunked
)
from .packages.six import string_types as basestring, binary_type, PY3
from .packages.six.moves import http_client as httplib
from .connection import HTTPException, BaseSSLError
from .util.response import is_fp_closed
from .util.response import is_fp_closed, is_response_to_head
class DeflateDecoder(object):
@ -117,6 +123,16 @@ class HTTPResponse(io.IOBase):
if hasattr(body, 'read'):
self._fp = body
# Are we using the chunked-style of transfer encoding?
self.chunked = False
self.chunk_left = None
tr_enc = self.headers.get('transfer-encoding', '').lower()
# Don't incur the penalty of creating a list and then discarding it
encodings = (enc.strip() for enc in tr_enc.split(","))
if "chunked" in encodings:
self.chunked = True
# If requested, preload the body.
if preload_content and not self._body:
self._body = self.read(decode_content=decode_content)
@ -157,6 +173,93 @@ class HTTPResponse(io.IOBase):
"""
return self._fp_bytes_read
def _init_decoder(self):
"""
Set-up the _decoder attribute if necessar.
"""
# Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower()
if self._decoder is None and content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
def _decode(self, data, decode_content, flush_decoder):
"""
Decode the data passed in and potentially flush the decoder.
"""
try:
if decode_content and self._decoder:
data = self._decoder.decompress(data)
except (IOError, zlib.error) as e:
content_encoding = self.headers.get('content-encoding', '').lower()
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding, e)
if flush_decoder and decode_content:
data += self._flush_decoder()
return data
def _flush_decoder(self):
"""
Flushes the decoder. Should only be called if the decoder is actually
being used.
"""
if self._decoder:
buf = self._decoder.decompress(b'')
return buf + self._decoder.flush()
return b''
@contextmanager
def _error_catcher(self):
"""
Catch low-level python exceptions, instead re-raising urllib3
variants, so that low-level exceptions are not leaked in the
high-level api.
On exit, release the connection back to the pool.
"""
try:
try:
yield
except SocketTimeout:
# FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
# there is yet no clean way to get at it from this context.
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
except BaseSSLError as e:
# FIXME: Is there a better way to differentiate between SSLErrors?
if 'read operation timed out' not in str(e): # Defensive:
# This shouldn't happen but just in case we're missing an edge
# case, let's avoid swallowing SSL errors.
raise
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
except (HTTPException, SocketError) as e:
# This includes IncompleteRead.
raise ProtocolError('Connection broken: %r' % e, e)
except Exception:
# The response may not be closed but we're not going to use it anymore
# so close it now to ensure that the connection is released back to the pool.
if self._original_response and not self._original_response.isclosed():
self._original_response.close()
# Closing the response may not actually be sufficient to close
# everything, so if we have a hold of the connection close that
# too.
if self._connection is not None:
self._connection.close()
raise
finally:
if self._original_response and self._original_response.isclosed():
self.release_conn()
def read(self, amt=None, decode_content=None, cache_content=False):
"""
Similar to :meth:`httplib.HTTPResponse.read`, but with two additional
@ -178,12 +281,7 @@ class HTTPResponse(io.IOBase):
after having ``.read()`` the file object. (Overridden if ``amt`` is
set.)
"""
# Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower()
if self._decoder is None:
if content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
self._init_decoder()
if decode_content is None:
decode_content = self.decode_content
@ -191,67 +289,36 @@ class HTTPResponse(io.IOBase):
return
flush_decoder = False
data = None
try:
try:
if amt is None:
# cStringIO doesn't like amt=None
data = self._fp.read()
with self._error_catcher():
if amt is None:
# cStringIO doesn't like amt=None
data = self._fp.read()
flush_decoder = True
else:
cache_content = False
data = self._fp.read(amt)
if amt != 0 and not data: # Platform-specific: Buggy versions of Python.
# Close the connection when no data is returned
#
# This is redundant to what httplib/http.client _should_
# already do. However, versions of python released before
# December 15, 2012 (http://bugs.python.org/issue16298) do
# not properly close the connection in all cases. There is
# no harm in redundantly calling close.
self._fp.close()
flush_decoder = True
else:
cache_content = False
data = self._fp.read(amt)
if amt != 0 and not data: # Platform-specific: Buggy versions of Python.
# Close the connection when no data is returned
#
# This is redundant to what httplib/http.client _should_
# already do. However, versions of python released before
# December 15, 2012 (http://bugs.python.org/issue16298) do
# not properly close the connection in all cases. There is
# no harm in redundantly calling close.
self._fp.close()
flush_decoder = True
except SocketTimeout:
# FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
# there is yet no clean way to get at it from this context.
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
except BaseSSLError as e:
# FIXME: Is there a better way to differentiate between SSLErrors?
if 'read operation timed out' not in str(e): # Defensive:
# This shouldn't happen but just in case we're missing an edge
# case, let's avoid swallowing SSL errors.
raise
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
except HTTPException as e:
# This includes IncompleteRead.
raise ProtocolError('Connection broken: %r' % e, e)
if data:
self._fp_bytes_read += len(data)
try:
if decode_content and self._decoder:
data = self._decoder.decompress(data)
except (IOError, zlib.error) as e:
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding, e)
if flush_decoder and decode_content and self._decoder:
buf = self._decoder.decompress(binary_type())
data += buf + self._decoder.flush()
data = self._decode(data, decode_content, flush_decoder)
if cache_content:
self._body = data
return data
finally:
if self._original_response and self._original_response.isclosed():
self.release_conn()
return data
def stream(self, amt=2**16, decode_content=None):
"""
@ -269,11 +336,15 @@ class HTTPResponse(io.IOBase):
If True, will attempt to decode the body based on the
'content-encoding' header.
"""
while not is_fp_closed(self._fp):
data = self.read(amt=amt, decode_content=decode_content)
if self.chunked:
for line in self.read_chunked(amt, decode_content=decode_content):
yield line
else:
while not is_fp_closed(self._fp):
data = self.read(amt=amt, decode_content=decode_content)
if data:
yield data
if data:
yield data
@classmethod
def from_httplib(ResponseCls, r, **response_kw):
@ -285,10 +356,11 @@ class HTTPResponse(io.IOBase):
with ``original_response=r``.
"""
headers = r.msg
if not isinstance(headers, HTTPHeaderDict):
if PY3: # Python 3
if PY3: # Python 3
headers = HTTPHeaderDict(headers.items())
else: # Python 2
else: # Python 2
headers = HTTPHeaderDict.from_httplib(headers)
# HTTPResponse objects in Python 3 don't have a .strict attribute
@ -351,3 +423,92 @@ class HTTPResponse(io.IOBase):
else:
b[:len(temp)] = temp
return len(temp)
def _update_chunk_length(self):
# First, we'll figure out length of a chunk and then
# we'll try to read it from socket.
if self.chunk_left is not None:
return
line = self._fp.fp.readline()
line = line.split(b';', 1)[0]
try:
self.chunk_left = int(line, 16)
except ValueError:
# Invalid chunked protocol response, abort.
self.close()
raise httplib.IncompleteRead(line)
def _handle_chunk(self, amt):
returned_chunk = None
if amt is None:
chunk = self._fp._safe_read(self.chunk_left)
returned_chunk = chunk
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
elif amt < self.chunk_left:
value = self._fp._safe_read(amt)
self.chunk_left = self.chunk_left - amt
returned_chunk = value
elif amt == self.chunk_left:
value = self._fp._safe_read(amt)
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
returned_chunk = value
else: # amt > self.chunk_left
returned_chunk = self._fp._safe_read(self.chunk_left)
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
return returned_chunk
def read_chunked(self, amt=None, decode_content=None):
"""
Similar to :meth:`HTTPResponse.read`, but with an additional
parameter: ``decode_content``.
:param decode_content:
If True, will attempt to decode the body based on the
'content-encoding' header.
"""
self._init_decoder()
# FIXME: Rewrite this method and make it a class with a better structured logic.
if not self.chunked:
raise ResponseNotChunked(
"Response is not chunked. "
"Header 'transfer-encoding: chunked' is missing.")
# Don't bother reading the body of a HEAD request.
if self._original_response and is_response_to_head(self._original_response):
self._original_response.close()
return
with self._error_catcher():
while True:
self._update_chunk_length()
if self.chunk_left == 0:
break
chunk = self._handle_chunk(amt)
decoded = self._decode(chunk, decode_content=decode_content,
flush_decoder=False)
if decoded:
yield decoded
if decode_content:
# On CPython and PyPy, we should never need to flush the
# decoder. However, on Jython we *might* need to, so
# lets defensively do it anyway.
decoded = self._flush_decoder()
if decoded: # Platform-specific: Jython.
yield decoded
# Chunk content ends with \r\n: discard it.
while True:
line = self._fp.fp.readline()
if not line:
# Some sites may not end with '\r\n'.
break
if line == b'\r\n':
break
# We read everything; close the "file".
if self._original_response:
self._original_response.close()

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
# For backwards compatibility, provide imports that used to be here.
from .connection import is_connection_dropped
from .request import make_headers
@ -22,3 +23,22 @@ from .url import (
split_first,
Url,
)
__all__ = (
'HAS_SNI',
'SSLContext',
'Retry',
'Timeout',
'Url',
'assert_fingerprint',
'current_time',
'is_connection_dropped',
'is_fp_closed',
'get_host',
'parse_url',
'make_headers',
'resolve_cert_reqs',
'resolve_ssl_version',
'split_first',
'ssl_wrap_socket',
)

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
import socket
try:
from select import poll, POLLIN
@ -60,6 +61,8 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
"""
host, port = address
if host.startswith('['):
host = host.strip('[]')
err = None
for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
@ -78,16 +81,16 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
sock.connect(sa)
return sock
except socket.error as _:
err = _
except socket.error as e:
err = e
if sock is not None:
sock.close()
sock = None
if err is not None:
raise err
else:
raise socket.error("getaddrinfo returns an empty list")
raise socket.error("getaddrinfo returns an empty list")
def _set_socket_options(sock, options):

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
from base64 import b64encode
from ..packages.six import b

View File

@ -1,3 +1,9 @@
from __future__ import absolute_import
from ..packages.six.moves import http_client as httplib
from ..exceptions import HeaderParsingError
def is_fp_closed(obj):
"""
Checks whether a given file-like object is closed.
@ -20,3 +26,49 @@ def is_fp_closed(obj):
pass
raise ValueError("Unable to determine whether fp is closed.")
def assert_header_parsing(headers):
"""
Asserts whether all headers have been successfully parsed.
Extracts encountered errors from the result of parsing headers.
Only works on Python 3.
:param headers: Headers to verify.
:type headers: `httplib.HTTPMessage`.
:raises urllib3.exceptions.HeaderParsingError:
If parsing errors are found.
"""
# This will fail silently if we pass in the wrong kind of parameter.
# To make debugging easier add an explicit check.
if not isinstance(headers, httplib.HTTPMessage):
raise TypeError('expected httplib.Message, got {0}.'.format(
type(headers)))
defects = getattr(headers, 'defects', None)
get_payload = getattr(headers, 'get_payload', None)
unparsed_data = None
if get_payload: # Platform-specific: Python 3.
unparsed_data = get_payload()
if defects or unparsed_data:
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
def is_response_to_head(response):
"""
Checks, wether a the request of a response has been a HEAD-request.
Handles the quirks of AppEngine.
:param conn:
:type conn: :class:`httplib.HTTPResponse`
"""
# FIXME: Can we do this somehow without accessing private httplib _method?
method = response._method
if isinstance(method, int): # Platform-specific: Appengine
return method == 3
return method.upper() == 'HEAD'

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
import time
import logging
@ -94,7 +95,7 @@ class Retry(object):
seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep
for [0.1s, 0.2s, 0.4s, ...] between retries. It will never be longer
than :attr:`Retry.MAX_BACKOFF`.
than :attr:`Retry.BACKOFF_MAX`.
By default, backoff is disabled (set to 0).
@ -126,7 +127,7 @@ class Retry(object):
self.method_whitelist = method_whitelist
self.backoff_factor = backoff_factor
self.raise_on_redirect = raise_on_redirect
self._observed_errors = _observed_errors # TODO: use .history instead?
self._observed_errors = _observed_errors # TODO: use .history instead?
def new(self, **kw):
params = dict(
@ -206,7 +207,8 @@ class Retry(object):
return min(retry_counts) < 0
def increment(self, method=None, url=None, response=None, error=None, _pool=None, _stacktrace=None):
def increment(self, method=None, url=None, response=None, error=None,
_pool=None, _stacktrace=None):
""" Return a new Retry object with incremented retry counters.
:param response: A response object, or None, if the server did not
@ -274,7 +276,6 @@ class Retry(object):
return new_retry
def __repr__(self):
return ('{cls.__name__}(total={self.total}, connect={self.connect}, '
'read={self.read}, redirect={self.redirect})').format(

View File

@ -1,18 +1,45 @@
from __future__ import absolute_import
import errno
import warnings
import hmac
from binascii import hexlify, unhexlify
from hashlib import md5, sha1, sha256
from ..exceptions import SSLError, InsecurePlatformWarning
from ..exceptions import SSLError, InsecurePlatformWarning, SNIMissingWarning
SSLContext = None
HAS_SNI = False
create_default_context = None
import errno
import ssl
import warnings
# Maps the length of a digest to a possible hash function producing this digest
HASHFUNC_MAP = {
32: md5,
40: sha1,
64: sha256,
}
def _const_compare_digest_backport(a, b):
"""
Compare two digests of equal length in constant time.
The digests must be of type str/bytes.
Returns True if the digests match, and False otherwise.
"""
result = abs(len(a) - len(b))
for l, r in zip(bytearray(a), bytearray(b)):
result |= l ^ r
return result == 0
_const_compare_digest = getattr(hmac, 'compare_digest',
_const_compare_digest_backport)
try: # Test for SSL features
import ssl
from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23
from ssl import HAS_SNI # Has SNI?
except ImportError:
@ -25,14 +52,24 @@ except ImportError:
OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000
OP_NO_COMPRESSION = 0x20000
try:
from ssl import _DEFAULT_CIPHERS
except ImportError:
_DEFAULT_CIPHERS = (
'ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:ECDH+HIGH:'
'DH+HIGH:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+HIGH:RSA+3DES:!aNULL:'
'!eNULL:!MD5'
)
# A secure default.
# Sources for more information on TLS ciphers:
#
# - https://wiki.mozilla.org/Security/Server_Side_TLS
# - https://www.ssllabs.com/projects/best-practices/index.html
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
#
# The general intent is:
# - Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
# - prefer ECDHE over DHE for better performance,
# - prefer any AES-GCM over any AES-CBC for better performance and security,
# - use 3DES as fallback which is secure but slow,
# - disable NULL authentication, MD5 MACs and DSS for security reasons.
DEFAULT_CIPHERS = (
'ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:ECDH+HIGH:'
'DH+HIGH:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+HIGH:RSA+3DES:!aNULL:'
'!eNULL:!MD5'
)
try:
from ssl import SSLContext # Modern SSL?
@ -40,7 +77,8 @@ except ImportError:
import sys
class SSLContext(object): # Platform-specific: Python 2 & 3.1
supports_set_ciphers = sys.version_info >= (2, 7)
supports_set_ciphers = ((2, 7) <= sys.version_info < (3,) or
(3, 2) <= sys.version_info)
def __init__(self, protocol_version):
self.protocol = protocol_version
@ -57,8 +95,11 @@ except ImportError:
self.certfile = certfile
self.keyfile = keyfile
def load_verify_locations(self, location):
self.ca_certs = location
def load_verify_locations(self, cafile=None, capath=None):
self.ca_certs = cafile
if capath is not None:
raise SSLError("CA directories not supported in older Pythons")
def set_ciphers(self, cipher_suite):
if not self.supports_set_ciphers:
@ -101,31 +142,21 @@ def assert_fingerprint(cert, fingerprint):
Fingerprint as string of hexdigits, can be interspersed by colons.
"""
# Maps the length of a digest to a possible hash function producing
# this digest.
hashfunc_map = {
16: md5,
20: sha1,
32: sha256,
}
fingerprint = fingerprint.replace(':', '').lower()
digest_length, odd = divmod(len(fingerprint), 2)
if odd or digest_length not in hashfunc_map:
raise SSLError('Fingerprint is of invalid length.')
digest_length = len(fingerprint)
hashfunc = HASHFUNC_MAP.get(digest_length)
if not hashfunc:
raise SSLError(
'Fingerprint of invalid length: {0}'.format(fingerprint))
# We need encode() here for py32; works on py2 and p33.
fingerprint_bytes = unhexlify(fingerprint.encode())
hashfunc = hashfunc_map[digest_length]
cert_digest = hashfunc(cert).digest()
if not cert_digest == fingerprint_bytes:
if not _const_compare_digest(cert_digest, fingerprint_bytes):
raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".'
.format(hexlify(fingerprint_bytes),
hexlify(cert_digest)))
.format(fingerprint, hexlify(cert_digest)))
def resolve_cert_reqs(candidate):
@ -167,7 +198,7 @@ def resolve_ssl_version(candidate):
return candidate
def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED,
def create_urllib3_context(ssl_version=None, cert_reqs=None,
options=None, ciphers=None):
"""All arguments have the same meaning as ``ssl_wrap_socket``.
@ -204,6 +235,9 @@ def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED,
"""
context = SSLContext(ssl_version or ssl.PROTOCOL_SSLv23)
# Setting the default here, as we may have no ssl module on import
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
if options is None:
options = 0
# SSLv2 is easily broken and is considered harmful and dangerous
@ -217,7 +251,7 @@ def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED,
context.options |= options
if getattr(context, 'supports_set_ciphers', True): # Platform-specific: Python 2.6
context.set_ciphers(ciphers or _DEFAULT_CIPHERS)
context.set_ciphers(ciphers or DEFAULT_CIPHERS)
context.verify_mode = cert_reqs
if getattr(context, 'check_hostname', None) is not None: # Platform-specific: Python 3.2
@ -229,10 +263,11 @@ def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED,
def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
ca_certs=None, server_hostname=None,
ssl_version=None, ciphers=None, ssl_context=None):
ssl_version=None, ciphers=None, ssl_context=None,
ca_cert_dir=None):
"""
All arguments except for server_hostname and ssl_context have the same
meaning as they do when using :func:`ssl.wrap_socket`.
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
the same meaning as they do when using :func:`ssl.wrap_socket`.
:param server_hostname:
When SNI is supported, the expected hostname of the certificate
@ -242,15 +277,19 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
:param ciphers:
A string of ciphers we wish the client to support. This is not
supported on Python 2.6 as the ssl module does not support it.
:param ca_cert_dir:
A directory containing CA certificates in multiple separate files, as
supported by OpenSSL's -CApath flag or the capath argument to
SSLContext.load_verify_locations().
"""
context = ssl_context
if context is None:
context = create_urllib3_context(ssl_version, cert_reqs,
ciphers=ciphers)
if ca_certs:
if ca_certs or ca_cert_dir:
try:
context.load_verify_locations(ca_certs)
context.load_verify_locations(ca_certs, ca_cert_dir)
except IOError as e: # Platform-specific: Python 2.6, 2.7, 3.2
raise SSLError(e)
# Py33 raises FileNotFoundError which subclasses OSError
@ -259,8 +298,20 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
if e.errno == errno.ENOENT:
raise SSLError(e)
raise
if certfile:
context.load_cert_chain(certfile, keyfile)
if HAS_SNI: # Platform-specific: OpenSSL with enabled SNI
return context.wrap_socket(sock, server_hostname=server_hostname)
warnings.warn(
'An HTTPS request has been made, but the SNI (Subject Name '
'Indication) extension to TLS is not available on this platform. '
'This may cause the server to present an incorrect TLS '
'certificate, which can cause validation failures. For more '
'information, see '
'https://urllib3.readthedocs.org/en/latest/security.html'
'#snimissingwarning.',
SNIMissingWarning
)
return context.wrap_socket(sock)

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
# The default socket timeout, used by httplib to indicate that no timeout was
# specified by the user
from socket import _GLOBAL_DEFAULT_TIMEOUT
@ -9,6 +10,7 @@ from ..exceptions import TimeoutStateError
# urllib3
_Default = object()
def current_time():
"""
Retrieve the current time. This function is mocked out in unit testing.
@ -226,9 +228,9 @@ class Timeout(object):
has not yet been called on this object.
"""
if (self.total is not None and
self.total is not self.DEFAULT_TIMEOUT and
self._read is not None and
self._read is not self.DEFAULT_TIMEOUT):
self.total is not self.DEFAULT_TIMEOUT and
self._read is not None and
self._read is not self.DEFAULT_TIMEOUT):
# In case the connect timeout has not yet been established.
if self._start_connect is None:
return self._read

View File

@ -1,3 +1,4 @@
from __future__ import absolute_import
from collections import namedtuple
from ..exceptions import LocationParseError
@ -15,6 +16,8 @@ class Url(namedtuple('Url', url_attrs)):
def __new__(cls, scheme=None, auth=None, host=None, port=None, path=None,
query=None, fragment=None):
if path and not path.startswith('/'):
path = '/' + path
return super(Url, cls).__new__(cls, scheme, auth, host, port, path,
query, fragment)
@ -83,6 +86,7 @@ class Url(namedtuple('Url', url_attrs)):
def __str__(self):
return self.url
def split_first(s, delims):
"""
Given a string and an iterable of delimiters, split on the first found
@ -113,7 +117,7 @@ def split_first(s, delims):
if min_idx is None or min_idx < 0:
return s, '', None
return s[:min_idx], s[min_idx+1:], min_delim
return s[:min_idx], s[min_idx + 1:], min_delim
def parse_url(url):
@ -204,6 +208,7 @@ def parse_url(url):
return Url(scheme, auth, host, port, path, query, fragment)
def get_host(url):
"""
Deprecated. Use :func:`.parse_url` instead.

View File

@ -62,12 +62,11 @@ def merge_setting(request_setting, session_setting, dict_class=OrderedDict):
merged_setting = dict_class(to_key_val_list(session_setting))
merged_setting.update(to_key_val_list(request_setting))
# Remove keys that are set to None.
for (k, v) in request_setting.items():
if v is None:
del merged_setting[k]
merged_setting = dict((k, v) for (k, v) in merged_setting.items() if v is not None)
# Remove keys that are set to None. Extract keys first to avoid altering
# the dictionary during iteration.
none_keys = [k for (k, v) in merged_setting.items() if v is None]
for key in none_keys:
del merged_setting[key]
return merged_setting
@ -90,7 +89,7 @@ def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict):
class SessionRedirectMixin(object):
def resolve_redirects(self, resp, req, stream=False, timeout=None,
verify=True, cert=None, proxies=None):
verify=True, cert=None, proxies=None, **adapter_kwargs):
"""Receives a Response. Returns a generator of Responses."""
i = 0
@ -111,7 +110,7 @@ class SessionRedirectMixin(object):
resp.raw.read(decode_content=False)
if i >= self.max_redirects:
raise TooManyRedirects('Exceeded %s redirects.' % self.max_redirects)
raise TooManyRedirects('Exceeded %s redirects.' % self.max_redirects, response=resp)
# Release the connection back into the pool.
resp.close()
@ -193,6 +192,7 @@ class SessionRedirectMixin(object):
cert=cert,
proxies=proxies,
allow_redirects=False,
**adapter_kwargs
)
extract_cookies_to_jar(self.cookies, prepared_request, resp.raw)
@ -273,7 +273,13 @@ class Session(SessionRedirectMixin):
>>> import requests
>>> s = requests.Session()
>>> s.get('http://httpbin.org/get')
200
<Response [200]>
Or as a context manager::
>>> with requests.Session() as s:
>>> s.get('http://httpbin.org/get')
<Response [200]>
"""
__attrs__ = [
@ -293,9 +299,9 @@ class Session(SessionRedirectMixin):
#: :class:`Request <Request>`.
self.auth = None
#: Dictionary mapping protocol to the URL of the proxy (e.g.
#: {'http': 'foo.bar:3128'}) to be used on each
#: :class:`Request <Request>`.
#: Dictionary mapping protocol or protocol and host to the URL of the proxy
#: (e.g. {'http': 'foo.bar:3128', 'http://host.name': 'foo.bar:4012'}) to
#: be used on each :class:`Request <Request>`.
self.proxies = {}
#: Event-handling hooks.
@ -319,7 +325,8 @@ class Session(SessionRedirectMixin):
#: limit, a :class:`TooManyRedirects` exception is raised.
self.max_redirects = DEFAULT_REDIRECT_LIMIT
#: Should we trust the environment?
#: Trust environment settings for proxy configuration, default
#: authentication and similar.
self.trust_env = True
#: A CookieJar containing all currently outstanding cookies set on this
@ -404,8 +411,8 @@ class Session(SessionRedirectMixin):
:param url: URL for the new :class:`Request` object.
:param params: (optional) Dictionary or bytes to be sent in the query
string for the :class:`Request`.
:param data: (optional) Dictionary or bytes to send in the body of the
:class:`Request`.
:param data: (optional) Dictionary, bytes, or file-like object to send
in the body of the :class:`Request`.
:param json: (optional) json to send in the body of the
:class:`Request`.
:param headers: (optional) Dictionary of HTTP Headers to send with the
@ -417,23 +424,20 @@ class Session(SessionRedirectMixin):
:param auth: (optional) Auth tuple or callable to enable
Basic/Digest/Custom HTTP Auth.
:param timeout: (optional) How long to wait for the server to send
data before giving up, as a float, or a (`connect timeout, read
timeout <user/advanced.html#timeouts>`_) tuple.
data before giving up, as a float, or a :ref:`(connect timeout,
read timeout) <timeouts>` tuple.
:type timeout: float or tuple
:param allow_redirects: (optional) Set to True by default.
:type allow_redirects: bool
:param proxies: (optional) Dictionary mapping protocol to the URL of
the proxy.
:param proxies: (optional) Dictionary mapping protocol or protocol and
hostname to the URL of the proxy.
:param stream: (optional) whether to immediately download the response
content. Defaults to ``False``.
:param verify: (optional) if ``True``, the SSL cert will be verified.
A CA_BUNDLE path can also be provided.
:param verify: (optional) whether the SSL cert will be verified.
A CA_BUNDLE path can also be provided. Defaults to ``True``.
:param cert: (optional) if String, path to ssl client cert file (.pem).
If Tuple, ('cert', 'key') pair.
"""
method = to_native_string(method)
# Create the Request.
req = Request(
method = method.upper(),
@ -549,23 +553,21 @@ class Session(SessionRedirectMixin):
if not isinstance(request, PreparedRequest):
raise ValueError('You can only send PreparedRequests.')
checked_urls = set()
while request.url in self.redirect_cache:
checked_urls.add(request.url)
new_url = self.redirect_cache.get(request.url)
if new_url in checked_urls:
break
request.url = new_url
# Set up variables needed for resolve_redirects and dispatching of hooks
allow_redirects = kwargs.pop('allow_redirects', True)
stream = kwargs.get('stream')
timeout = kwargs.get('timeout')
verify = kwargs.get('verify')
cert = kwargs.get('cert')
proxies = kwargs.get('proxies')
hooks = request.hooks
# Resolve URL in redirect cache, if available.
if allow_redirects:
checked_urls = set()
while request.url in self.redirect_cache:
checked_urls.add(request.url)
new_url = self.redirect_cache.get(request.url)
if new_url in checked_urls:
break
request.url = new_url
# Get the appropriate adapter to use
adapter = self.get_adapter(url=request.url)
@ -591,12 +593,7 @@ class Session(SessionRedirectMixin):
extract_cookies_to_jar(self.cookies, request, r.raw)
# Redirect resolving generator.
gen = self.resolve_redirects(r, request,
stream=stream,
timeout=timeout,
verify=verify,
cert=cert,
proxies=proxies)
gen = self.resolve_redirects(r, request, **kwargs)
# Resolve redirects if allowed.
history = [resp for resp in gen] if allow_redirects else []
@ -639,7 +636,7 @@ class Session(SessionRedirectMixin):
'cert': cert}
def get_adapter(self, url):
"""Returns the appropriate connnection adapter for the given URL."""
"""Returns the appropriate connection adapter for the given URL."""
for (prefix, adapter) in self.adapters.items():
if url.lower().startswith(prefix):

View File

@ -78,11 +78,12 @@ _codes = {
507: ('insufficient_storage',),
509: ('bandwidth_limit_exceeded', 'bandwidth'),
510: ('not_extended',),
511: ('network_authentication_required', 'network_auth', 'network_authentication'),
}
codes = LookupDict(name='status_codes')
for (code, titles) in list(_codes.items()):
for code, titles in _codes.items():
for title in titles:
setattr(codes, title, code)
if not title.startswith('\\'):

View File

@ -29,7 +29,7 @@ from .compat import (quote, urlparse, bytes, str, OrderedDict, unquote, is_py2,
basestring)
from .cookies import RequestsCookieJar, cookiejar_from_dict
from .structures import CaseInsensitiveDict
from .exceptions import InvalidURL
from .exceptions import InvalidURL, FileModeWarning
_hush_pyflakes = (RequestsCookieJar,)
@ -48,26 +48,47 @@ def dict_to_sequence(d):
def super_len(o):
total_length = 0
current_position = 0
if hasattr(o, '__len__'):
return len(o)
total_length = len(o)
if hasattr(o, 'len'):
return o.len
elif hasattr(o, 'len'):
total_length = o.len
if hasattr(o, 'fileno'):
elif hasattr(o, 'getvalue'):
# e.g. BytesIO, cStringIO.StringIO
total_length = len(o.getvalue())
elif hasattr(o, 'fileno'):
try:
fileno = o.fileno()
except io.UnsupportedOperation:
pass
else:
return os.fstat(fileno).st_size
total_length = os.fstat(fileno).st_size
if hasattr(o, 'getvalue'):
# e.g. BytesIO, cStringIO.StringIO
return len(o.getvalue())
# Having used fstat to determine the file length, we need to
# confirm that this file was opened up in binary mode.
if 'b' not in o.mode:
warnings.warn((
"Requests has determined the content-length for this "
"request using the binary size of the file: however, the "
"file has been opened in text mode (i.e. without the 'b' "
"flag in the mode). This may lead to an incorrect "
"content-length. In Requests 3.0, support will be removed "
"for files in text mode."),
FileModeWarning
)
if hasattr(o, 'tell'):
current_position = o.tell()
return max(0, total_length - current_position)
def get_netrc_auth(url):
def get_netrc_auth(url, raise_errors=False):
"""Returns the Requests tuple auth for a given url from netrc."""
try:
@ -94,8 +115,12 @@ def get_netrc_auth(url):
ri = urlparse(url)
# Strip port numbers from netloc
host = ri.netloc.split(':')[0]
# Strip port numbers from netloc. This weird `if...encode`` dance is
# used for Python 3.2, which doesn't support unicode literals.
splitstr = b':'
if isinstance(url, str):
splitstr = splitstr.decode('ascii')
host = ri.netloc.split(splitstr)[0]
try:
_netrc = netrc(netrc_path).authenticators(host)
@ -105,8 +130,9 @@ def get_netrc_auth(url):
return (_netrc[login_i], _netrc[2])
except (NetrcParseError, IOError):
# If there was a parsing error or a permissions issue reading the file,
# we'll just skip netrc auth
pass
# we'll just skip netrc auth unless explicitly asked to raise errors.
if raise_errors:
raise
# AppEngine hackiness.
except (ImportError, AttributeError):
@ -498,7 +524,9 @@ def should_bypass_proxies(url):
if no_proxy:
# We need to check whether we match here. We need to see if we match
# the end of the netloc, both with and without the port.
no_proxy = no_proxy.replace(' ', '').split(',')
no_proxy = (
host for host in no_proxy.replace(' ', '').split(',') if host
)
ip = netloc.split(':')[0]
if is_ipv4_address(ip):
@ -536,36 +564,22 @@ def get_environ_proxies(url):
else:
return getproxies()
def select_proxy(url, proxies):
"""Select a proxy for the url, if applicable.
:param url: The url being for the request
:param proxies: A dictionary of schemes or schemes and hosts to proxy URLs
"""
proxies = proxies or {}
urlparts = urlparse(url)
proxy = proxies.get(urlparts.scheme+'://'+urlparts.hostname)
if proxy is None:
proxy = proxies.get(urlparts.scheme)
return proxy
def default_user_agent(name="python-requests"):
"""Return a string representing the default user agent."""
_implementation = platform.python_implementation()
if _implementation == 'CPython':
_implementation_version = platform.python_version()
elif _implementation == 'PyPy':
_implementation_version = '%s.%s.%s' % (sys.pypy_version_info.major,
sys.pypy_version_info.minor,
sys.pypy_version_info.micro)
if sys.pypy_version_info.releaselevel != 'final':
_implementation_version = ''.join([_implementation_version, sys.pypy_version_info.releaselevel])
elif _implementation == 'Jython':
_implementation_version = platform.python_version() # Complete Guess
elif _implementation == 'IronPython':
_implementation_version = platform.python_version() # Complete Guess
else:
_implementation_version = 'Unknown'
try:
p_system = platform.system()
p_release = platform.release()
except IOError:
p_system = 'Unknown'
p_release = 'Unknown'
return " ".join(['%s/%s' % (name, __version__),
'%s/%s' % (_implementation, _implementation_version),
'%s/%s' % (p_system, p_release)])
return '%s/%s' % (name, __version__)
def default_headers():

View File

@ -5,9 +5,8 @@ interchange format.
:mod:`simplejson` exposes an API familiar to users of the standard library
:mod:`marshal` and :mod:`pickle` modules. It is the externally maintained
version of the :mod:`json` library contained in Python 2.6, but maintains
compatibility with Python 2.4 and Python 2.5 and (currently) has
significant performance advantages, even without using the optional C
extension for speedups.
compatibility back to Python 2.5 and (currently) has significant performance
advantages, even without using the optional C extension for speedups.
Encoding basic Python object hierarchies::
@ -98,7 +97,7 @@ Using simplejson.tool from the shell to validate and pretty-print::
Expecting property name: line 1 column 3 (char 2)
"""
from __future__ import absolute_import
__version__ = '3.6.5'
__version__ = '3.8.0'
__all__ = [
'dump', 'dumps', 'load', 'loads',
'JSONDecoder', 'JSONDecodeError', 'JSONEncoder',
@ -140,6 +139,7 @@ _default_encoder = JSONEncoder(
use_decimal=True,
namedtuple_as_object=True,
tuple_as_array=True,
iterable_as_array=False,
bigint_as_string=False,
item_sort_key=None,
for_json=False,
@ -152,7 +152,8 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
encoding='utf-8', default=None, use_decimal=True,
namedtuple_as_object=True, tuple_as_array=True,
bigint_as_string=False, sort_keys=False, item_sort_key=None,
for_json=False, ignore_nan=False, int_as_string_bitcount=None, **kw):
for_json=False, ignore_nan=False, int_as_string_bitcount=None,
iterable_as_array=False, **kw):
"""Serialize ``obj`` as a JSON formatted stream to ``fp`` (a
``.write()``-supporting file-like object).
@ -204,6 +205,10 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
If *tuple_as_array* is true (default: ``True``),
:class:`tuple` (and subclasses) will be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise. Note that this is still a
@ -242,7 +247,7 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
check_circular and allow_nan and
cls is None and indent is None and separators is None and
encoding == 'utf-8' and default is None and use_decimal
and namedtuple_as_object and tuple_as_array
and namedtuple_as_object and tuple_as_array and not iterable_as_array
and not bigint_as_string and not sort_keys
and not item_sort_key and not for_json
and not ignore_nan and int_as_string_bitcount is None
@ -258,6 +263,7 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
default=default, use_decimal=use_decimal,
namedtuple_as_object=namedtuple_as_object,
tuple_as_array=tuple_as_array,
iterable_as_array=iterable_as_array,
bigint_as_string=bigint_as_string,
sort_keys=sort_keys,
item_sort_key=item_sort_key,
@ -276,7 +282,8 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
encoding='utf-8', default=None, use_decimal=True,
namedtuple_as_object=True, tuple_as_array=True,
bigint_as_string=False, sort_keys=False, item_sort_key=None,
for_json=False, ignore_nan=False, int_as_string_bitcount=None, **kw):
for_json=False, ignore_nan=False, int_as_string_bitcount=None,
iterable_as_array=False, **kw):
"""Serialize ``obj`` to a JSON formatted ``str``.
If ``skipkeys`` is false then ``dict`` keys that are not basic types
@ -324,6 +331,10 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
If *tuple_as_array* is true (default: ``True``),
:class:`tuple` (and subclasses) will be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If *bigint_as_string* is true (not the default), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise.
@ -356,12 +367,11 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
"""
# cached encoder
if (
not skipkeys and ensure_ascii and
if (not skipkeys and ensure_ascii and
check_circular and allow_nan and
cls is None and indent is None and separators is None and
encoding == 'utf-8' and default is None and use_decimal
and namedtuple_as_object and tuple_as_array
and namedtuple_as_object and tuple_as_array and not iterable_as_array
and not bigint_as_string and not sort_keys
and not item_sort_key and not for_json
and not ignore_nan and int_as_string_bitcount is None
@ -377,6 +387,7 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
use_decimal=use_decimal,
namedtuple_as_object=namedtuple_as_object,
tuple_as_array=tuple_as_array,
iterable_as_array=iterable_as_array,
bigint_as_string=bigint_as_string,
sort_keys=sort_keys,
item_sort_key=item_sort_key,

View File

@ -10,6 +10,7 @@
#define PyString_AS_STRING PyBytes_AS_STRING
#define PyString_FromStringAndSize PyBytes_FromStringAndSize
#define PyInt_Check(obj) 0
#define PyInt_CheckExact(obj) 0
#define JSON_UNICHR Py_UCS4
#define JSON_InternFromString PyUnicode_InternFromString
#define JSON_Intern_GET_SIZE PyUnicode_GET_SIZE
@ -168,6 +169,7 @@ typedef struct _PyEncoderObject {
int use_decimal;
int namedtuple_as_object;
int tuple_as_array;
int iterable_as_array;
PyObject *max_long_size;
PyObject *min_long_size;
PyObject *item_sort_key;
@ -660,7 +662,20 @@ encoder_stringify_key(PyEncoderObject *s, PyObject *key)
return _encoded_const(key);
}
else if (PyInt_Check(key) || PyLong_Check(key)) {
return PyObject_Str(key);
if (!(PyInt_CheckExact(key) || PyLong_CheckExact(key))) {
/* See #118, do not trust custom str/repr */
PyObject *res;
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, key, NULL);
if (tmp == NULL) {
return NULL;
}
res = PyObject_Str(tmp);
Py_DECREF(tmp);
return res;
}
else {
return PyObject_Str(key);
}
}
else if (s->use_decimal && PyObject_TypeCheck(key, (PyTypeObject *)s->Decimal)) {
return PyObject_Str(key);
@ -2567,7 +2582,6 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
static int
encoder_init(PyObject *self, PyObject *args, PyObject *kwds)
{
/* initialize Encoder object */
static char *kwlist[] = {
"markers",
"default",
@ -2582,30 +2596,32 @@ encoder_init(PyObject *self, PyObject *args, PyObject *kwds)
"use_decimal",
"namedtuple_as_object",
"tuple_as_array",
"iterable_as_array"
"int_as_string_bitcount",
"item_sort_key",
"encoding",
"for_json",
"ignore_nan",
"Decimal",
"iterable_as_array",
NULL};
PyEncoderObject *s;
PyObject *markers, *defaultfn, *encoder, *indent, *key_separator;
PyObject *item_separator, *sort_keys, *skipkeys, *allow_nan, *key_memo;
PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array;
PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array, *iterable_as_array;
PyObject *int_as_string_bitcount, *item_sort_key, *encoding, *for_json;
PyObject *ignore_nan, *Decimal;
assert(PyEncoder_Check(self));
s = (PyEncoderObject *)self;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOOOOO:make_encoder", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOOOOOO:make_encoder", kwlist,
&markers, &defaultfn, &encoder, &indent, &key_separator, &item_separator,
&sort_keys, &skipkeys, &allow_nan, &key_memo, &use_decimal,
&namedtuple_as_object, &tuple_as_array,
&int_as_string_bitcount, &item_sort_key, &encoding, &for_json,
&ignore_nan, &Decimal))
&ignore_nan, &Decimal, &iterable_as_array))
return -1;
Py_INCREF(markers);
@ -2635,9 +2651,10 @@ encoder_init(PyObject *self, PyObject *args, PyObject *kwds)
s->use_decimal = PyObject_IsTrue(use_decimal);
s->namedtuple_as_object = PyObject_IsTrue(namedtuple_as_object);
s->tuple_as_array = PyObject_IsTrue(tuple_as_array);
s->iterable_as_array = PyObject_IsTrue(iterable_as_array);
if (PyInt_Check(int_as_string_bitcount) || PyLong_Check(int_as_string_bitcount)) {
static const unsigned int long_long_bitsize = SIZEOF_LONG_LONG * 8;
int int_as_string_bitcount_val = PyLong_AsLong(int_as_string_bitcount);
int int_as_string_bitcount_val = (int)PyLong_AsLong(int_as_string_bitcount);
if (int_as_string_bitcount_val > 0 && int_as_string_bitcount_val < long_long_bitsize) {
s->max_long_size = PyLong_FromUnsignedLongLong(1ULL << int_as_string_bitcount_val);
s->min_long_size = PyLong_FromLongLong(-1LL << int_as_string_bitcount_val);
@ -2800,7 +2817,20 @@ encoder_encode_float(PyEncoderObject *s, PyObject *obj)
}
}
/* Use a better float format here? */
return PyObject_Repr(obj);
if (PyFloat_CheckExact(obj)) {
return PyObject_Repr(obj);
}
else {
/* See #118, do not trust custom str/repr */
PyObject *res;
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyFloat_Type, obj, NULL);
if (tmp == NULL) {
return NULL;
}
res = PyObject_Repr(tmp);
Py_DECREF(tmp);
return res;
}
}
static PyObject *
@ -2840,7 +2870,21 @@ encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ss
rv = _steal_accumulate(rval, encoded);
}
else if (PyInt_Check(obj) || PyLong_Check(obj)) {
PyObject *encoded = PyObject_Str(obj);
PyObject *encoded;
if (PyInt_CheckExact(obj) || PyLong_CheckExact(obj)) {
encoded = PyObject_Str(obj);
}
else {
/* See #118, do not trust custom str/repr */
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, obj, NULL);
if (tmp == NULL) {
encoded = NULL;
}
else {
encoded = PyObject_Str(tmp);
Py_DECREF(tmp);
}
}
if (encoded != NULL) {
encoded = maybe_quote_bigint(s, encoded, obj);
if (encoded == NULL)
@ -2895,6 +2939,16 @@ encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ss
else {
PyObject *ident = NULL;
PyObject *newobj;
if (s->iterable_as_array) {
newobj = PyObject_GetIter(obj);
if (newobj == NULL)
PyErr_Clear();
else {
rv = encoder_listencode_list(s, rval, newobj, indent_level);
Py_DECREF(newobj);
break;
}
}
if (s->markers != Py_None) {
int has_key;
ident = PyLong_FromVoidPtr(obj);

View File

@ -3,7 +3,8 @@
from __future__ import absolute_import
import re
from operator import itemgetter
from decimal import Decimal
# Do not import Decimal directly to avoid reload issues
import decimal
from .compat import u, unichr, binary_type, string_types, integer_types, PY3
def _import_speedups():
try:
@ -123,7 +124,7 @@ class JSONEncoder(object):
use_decimal=True, namedtuple_as_object=True,
tuple_as_array=True, bigint_as_string=False,
item_sort_key=None, for_json=False, ignore_nan=False,
int_as_string_bitcount=None):
int_as_string_bitcount=None, iterable_as_array=False):
"""Constructor for JSONEncoder, with sensible defaults.
If skipkeys is false, then it is a TypeError to attempt
@ -178,6 +179,10 @@ class JSONEncoder(object):
If tuple_as_array is true (the default), tuple (and subclasses) will
be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If bigint_as_string is true (not the default), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise.
@ -209,6 +214,7 @@ class JSONEncoder(object):
self.use_decimal = use_decimal
self.namedtuple_as_object = namedtuple_as_object
self.tuple_as_array = tuple_as_array
self.iterable_as_array = iterable_as_array
self.bigint_as_string = bigint_as_string
self.item_sort_key = item_sort_key
self.for_json = for_json
@ -311,6 +317,9 @@ class JSONEncoder(object):
elif o == _neginf:
text = '-Infinity'
else:
if type(o) != float:
# See #118, do not trust custom str/repr
o = float(o)
return _repr(o)
if ignore_nan:
@ -334,7 +343,7 @@ class JSONEncoder(object):
self.namedtuple_as_object, self.tuple_as_array,
int_as_string_bitcount,
self.item_sort_key, self.encoding, self.for_json,
self.ignore_nan, Decimal)
self.ignore_nan, decimal.Decimal, self.iterable_as_array)
else:
_iterencode = _make_iterencode(
markers, self.default, _encoder, self.indent, floatstr,
@ -343,7 +352,7 @@ class JSONEncoder(object):
self.namedtuple_as_object, self.tuple_as_array,
int_as_string_bitcount,
self.item_sort_key, self.encoding, self.for_json,
Decimal=Decimal)
self.iterable_as_array, Decimal=decimal.Decimal)
try:
return _iterencode(o, 0)
finally:
@ -382,11 +391,12 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
_use_decimal, _namedtuple_as_object, _tuple_as_array,
_int_as_string_bitcount, _item_sort_key,
_encoding,_for_json,
_iterable_as_array,
## HACK: hand-optimized bytecode; turn globals into locals
_PY3=PY3,
ValueError=ValueError,
string_types=string_types,
Decimal=Decimal,
Decimal=None,
dict=dict,
float=float,
id=id,
@ -395,7 +405,10 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
list=list,
str=str,
tuple=tuple,
iter=iter,
):
if _use_decimal and Decimal is None:
Decimal = decimal.Decimal
if _item_sort_key and not callable(_item_sort_key):
raise TypeError("item_sort_key must be None or callable")
elif _sort_keys and not _item_sort_key:
@ -412,6 +425,9 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
or
_int_as_string_bitcount < 1
)
if type(value) not in integer_types:
# See #118, do not trust custom str/repr
value = int(value)
if (
skip_quoting or
(-1 << _int_as_string_bitcount)
@ -501,6 +517,9 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
elif key is None:
key = 'null'
elif isinstance(key, integer_types):
if type(key) not in integer_types:
# See #118, do not trust custom str/repr
key = int(key)
key = str(key)
elif _use_decimal and isinstance(key, Decimal):
key = str(key)
@ -634,6 +653,16 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
elif _use_decimal and isinstance(o, Decimal):
yield str(o)
else:
while _iterable_as_array:
# Markers are not checked here because it is valid for
# an iterable to return self.
try:
o = iter(o)
except TypeError:
break
for chunk in _iterencode_list(o, _current_indent_level):
yield chunk
return
if markers is not None:
markerid = id(o)
if markerid in markers:

View File

@ -62,6 +62,7 @@ def all_tests_suite():
'simplejson.tests.test_namedtuple',
'simplejson.tests.test_tool',
'simplejson.tests.test_for_json',
'simplejson.tests.test_subclass',
]))
suite = get_suite()
import simplejson

View File

@ -0,0 +1,31 @@
import unittest
from StringIO import StringIO
import simplejson as json
def iter_dumps(obj, **kw):
return ''.join(json.JSONEncoder(**kw).iterencode(obj))
def sio_dump(obj, **kw):
sio = StringIO()
json.dumps(obj, **kw)
return sio.getvalue()
class TestIterable(unittest.TestCase):
def test_iterable(self):
l = [1, 2, 3]
for dumps in (json.dumps, iter_dumps, sio_dump):
expect = dumps(l)
default_expect = dumps(sum(l))
# Default is False
self.assertRaises(TypeError, dumps, iter(l))
self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False)
self.assertEqual(expect, dumps(iter(l), iterable_as_array=True))
# Ensure that the "default" gets called
self.assertEqual(default_expect, dumps(iter(l), default=sum))
self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum))
# Ensure that the "default" does not get called
self.assertEqual(
default_expect,
dumps(iter(l), iterable_as_array=True, default=sum))

View File

@ -0,0 +1,37 @@
from unittest import TestCase
import simplejson as json
from decimal import Decimal
class AlternateInt(int):
def __repr__(self):
return 'invalid json'
__str__ = __repr__
class AlternateFloat(float):
def __repr__(self):
return 'invalid json'
__str__ = __repr__
# class AlternateDecimal(Decimal):
# def __repr__(self):
# return 'invalid json'
class TestSubclass(TestCase):
def test_int(self):
self.assertEqual(json.dumps(AlternateInt(1)), '1')
self.assertEqual(json.dumps(AlternateInt(-1)), '-1')
self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1})
def test_float(self):
self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0')
self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0')
self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1})
# NOTE: Decimal subclasses are not supported as-is
# def test_decimal(self):
# self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0')
# self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0')

View File

@ -45,7 +45,3 @@ class TestTuples(unittest.TestCase):
self.assertEqual(
json.dumps(repr(t)),
sio.getvalue())
class TestNamedTuple(unittest.TestCase):
def test_namedtuple_dump(self):
pass

View File

@ -15,30 +15,70 @@ from .projects.git import Git
from .projects.mercurial import Mercurial
from .projects.projectmap import ProjectMap
from .projects.subversion import Subversion
from .projects.wakatime import WakaTime
from .projects.wakatime_project_file import WakaTimeProjectFile
log = logging.getLogger('WakaTime')
# List of plugin classes to find a project for the current file path.
# Project plugins will be processed with priority in the order below.
PLUGINS = [
WakaTime,
CONFIG_PLUGINS = [
WakaTimeProjectFile,
ProjectMap,
]
REV_CONTROL_PLUGINS = [
Git,
Mercurial,
Subversion,
]
def find_project(path, configs=None):
for plugin in PLUGINS:
plugin_name = plugin.__name__.lower()
plugin_configs = None
if configs and configs.has_section(plugin_name):
plugin_configs = dict(configs.items(plugin_name))
project = plugin(path, configs=plugin_configs)
def get_project_info(configs, args):
"""Find the current project and branch.
First looks for a .wakatime-project file. Second, uses the --project arg.
Third, uses the folder name from a revision control repository. Last, uses
the --alternate-project arg.
Returns a project, branch tuple.
"""
project_name, branch_name = None, None
for plugin_cls in CONFIG_PLUGINS:
plugin_name = plugin_cls.__name__.lower()
plugin_configs = get_configs_for_plugin(plugin_name, configs)
project = plugin_cls(args.entity, configs=plugin_configs)
if project.process():
return project
project_name = project_name or project.name()
branch_name = project.branch()
break
if project_name is None:
project_name = args.project
if project_name is None or branch_name is None:
for plugin_cls in REV_CONTROL_PLUGINS:
plugin_name = plugin_cls.__name__.lower()
plugin_configs = get_configs_for_plugin(plugin_name, configs)
project = plugin_cls(args.entity, configs=plugin_configs)
if project.process():
project_name = project_name or project.name()
branch_name = branch_name or project.branch()
break
if project_name is None:
project_name = args.alternate_project
return project_name, branch_name
def get_configs_for_plugin(plugin_name, configs):
if configs and configs.has_section(plugin_name):
return dict(configs.items(plugin_name))
return None

View File

@ -11,6 +11,8 @@
import logging
from ..exceptions import NotYetImplemented
log = logging.getLogger('WakaTime')
@ -25,29 +27,19 @@ class BaseProject(object):
self.path = path
self._configs = configs
def project_type(self):
""" Returns None if this is the base class.
Returns the type of project if this is a
valid project.
"""
project_type = self.__class__.__name__.lower()
if project_type == 'baseproject':
project_type = None
return project_type
def process(self):
""" Processes self.path into a project and
returns True if project is valid, otherwise
returns False.
"""
return False
raise NotYetImplemented()
def name(self):
""" Returns the project's name.
"""
return None
raise NotYetImplemented()
def branch(self):
""" Returns the current branch.
"""
return None
raise NotYetImplemented()

View File

@ -11,6 +11,7 @@
import logging
import os
import sys
from .base import BaseProject
from ..compat import u, open
@ -29,7 +30,7 @@ class Git(BaseProject):
base = self._project_base()
if base:
return u(os.path.basename(base))
return None
return None # pragma: nocover
def branch(self):
base = self._project_base()
@ -38,8 +39,14 @@ class Git(BaseProject):
try:
with open(head, 'r', encoding='utf-8') as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1])
except IOError:
pass
except UnicodeDecodeError: # pragma: nocover
try:
with open(head, 'r', encoding=sys.getfilesystemencoding()) as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1])
except:
log.traceback('warn')
except IOError: # pragma: nocover
log.traceback('warn')
return None
def _project_base(self):

View File

@ -11,6 +11,7 @@
import logging
import os
import sys
from .base import BaseProject
from ..compat import u, open
@ -28,7 +29,7 @@ class Mercurial(BaseProject):
def name(self):
if self.configDir:
return u(os.path.basename(os.path.dirname(self.configDir)))
return None
return None # pragma: nocover
def branch(self):
if self.configDir:
@ -36,8 +37,14 @@ class Mercurial(BaseProject):
try:
with open(branch_file, 'r', encoding='utf-8') as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1])
except IOError:
pass
except UnicodeDecodeError: # pragma: nocover
try:
with open(branch_file, 'r', encoding=sys.getfilesystemencoding()) as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1])
except:
log.traceback('warn')
except IOError: # pragma: nocover
log.traceback('warn')
return u('default')
def _find_hg_config_dir(self, path):

View File

@ -47,14 +47,14 @@ class ProjectMap(BaseProject):
if self._configs.get(path.lower()):
return self._configs.get(path.lower())
if self._configs.get('%s/' % path.lower()):
if self._configs.get('%s/' % path.lower()): # pragma: nocover
return self._configs.get('%s/' % path.lower())
if self._configs.get('%s\\' % path.lower()):
if self._configs.get('%s\\' % path.lower()): # pragma: nocover
return self._configs.get('%s\\' % path.lower())
split_path = os.path.split(path)
if split_path[1] == '':
return None
return None # pragma: nocover
return self._find_project(split_path[0])
def branch(self):
@ -63,4 +63,4 @@ class ProjectMap(BaseProject):
def name(self):
if self.project:
return u(self.project)
return None
return None # pragma: nocover

View File

@ -18,8 +18,8 @@ from .base import BaseProject
from ..compat import u, open
try:
from collections import OrderedDict
except ImportError:
from ..packages.ordereddict import OrderedDict
except ImportError: # pragma: nocover
from ..packages.ordereddict import OrderedDict # pragma: nocover
log = logging.getLogger('WakaTime')
@ -32,10 +32,14 @@ class Subversion(BaseProject):
return self._find_project_base(self.path)
def name(self):
return u(self.info['Repository Root'].split('/')[-1])
if 'Repository Root' not in self.info:
return None # pragma: nocover
return u(self.info['Repository Root'].split('/')[-1].split('\\')[-1])
def branch(self):
return u(self.info['URL'].split('/')[-1])
if 'URL' not in self.info:
return None # pragma: nocover
return u(self.info['URL'].split('/')[-1].split('\\')[-1])
def _find_binary(self):
if self.binary_location:
@ -46,13 +50,13 @@ class Subversion(BaseProject):
'/usr/local/bin/svn',
]
for location in locations:
with open(os.devnull, 'wb') as DEVNULL:
try:
try:
with open(os.devnull, 'wb') as DEVNULL:
Popen([location, '--version'], stdout=DEVNULL, stderr=DEVNULL)
self.binary_location = location
return location
except:
pass
except:
pass
self.binary_location = 'svn'
return 'svn'
@ -69,8 +73,7 @@ class Subversion(BaseProject):
else:
if stdout:
for line in stdout.splitlines():
if isinstance(line, bytes):
line = bytes.decode(line)
line = u(line)
line = line.split(': ', 1)
if len(line) == 2:
info[line[0]] = line[1]
@ -78,7 +81,7 @@ class Subversion(BaseProject):
def _find_project_base(self, path, found=False):
if platform.system() == 'Windows':
return False
return False # pragma: nocover
path = os.path.realpath(path)
if os.path.isfile(path):
path = os.path.split(path)[0]

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
wakatime.projects.wakatime
~~~~~~~~~~~~~~~~~~~~~~~~~~
wakatime.projects.wakatime_project_file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Information from a .wakatime-project file about the project for
a given file. First line of .wakatime-project sets the project
@ -13,6 +13,7 @@
import logging
import os
import sys
from .base import BaseProject
from ..compat import u, open
@ -21,7 +22,7 @@ from ..compat import u, open
log = logging.getLogger('WakaTime')
class WakaTime(BaseProject):
class WakaTimeProjectFile(BaseProject):
def process(self):
self.config = self._find_config(self.path)
@ -34,8 +35,15 @@ class WakaTime(BaseProject):
with open(self.config, 'r', encoding='utf-8') as fh:
self._project_name = u(fh.readline().strip())
self._project_branch = u(fh.readline().strip())
except IOError:
log.exception("Exception:")
except UnicodeDecodeError: # pragma: nocover
try:
with open(self.config, 'r', encoding=sys.getfilesystemencoding()) as fh:
self._project_name = u(fh.readline().strip())
self._project_branch = u(fh.readline().strip())
except:
log.traceback('warn')
except IOError: # pragma: nocover
log.traceback('warn')
return True
return False

View File

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
"""
wakatime.session_cache
~~~~~~~~~~~~~~~~~~~~~~
Persist requests.Session for multiprocess SSL handshake pooling.
:copyright: (c) 2015 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
import logging
import os
import pickle
import sys
try:
import sqlite3
HAS_SQL = True
except ImportError: # pragma: nocover
HAS_SQL = False
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'packages'))
from .packages import requests
log = logging.getLogger('WakaTime')
class SessionCache(object):
DB_FILE = os.path.join(os.path.expanduser('~'), '.wakatime.db')
def connect(self):
conn = sqlite3.connect(self.DB_FILE)
c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS session (
value BLOB)
''')
return (conn, c)
def save(self, session):
"""Saves a requests.Session object for the next heartbeat process.
"""
if not HAS_SQL: # pragma: nocover
return
try:
conn, c = self.connect()
c.execute('DELETE FROM session')
values = {
'value': sqlite3.Binary(pickle.dumps(session, protocol=2)),
}
c.execute('INSERT INTO session VALUES (:value)', values)
conn.commit()
conn.close()
except: # pragma: nocover
log.traceback('debug')
def get(self):
"""Returns a requests.Session object.
Gets Session from sqlite3 cache or creates a new Session.
"""
if not HAS_SQL: # pragma: nocover
return requests.session()
try:
conn, c = self.connect()
except:
log.traceback('debug')
return requests.session()
session = None
try:
c.execute('BEGIN IMMEDIATE')
c.execute('SELECT value FROM session LIMIT 1')
row = c.fetchone()
if row is not None:
session = pickle.loads(row[0])
except: # pragma: nocover
log.traceback('debug')
try:
conn.close()
except: # pragma: nocover
log.traceback('debug')
return session if session is not None else requests.session()
def delete(self):
"""Clears all cached Session objects.
"""
if not HAS_SQL: # pragma: nocover
return
try:
conn, c = self.connect()
c.execute('DELETE FROM session')
conn.commit()
conn.close()
except:
log.traceback('debug')

View File

@ -14,84 +14,155 @@ import os
import sys
from .compat import u, open
from .languages import DependencyParser
from .dependencies import DependencyParser
if sys.version_info[0] == 2:
if sys.version_info[0] == 2: # pragma: nocover
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'packages', 'pygments_py2'))
else:
else: # pragma: nocover
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'packages', 'pygments_py3'))
from pygments.lexers import guess_lexer_for_filename
from pygments.lexers import get_lexer_by_name, guess_lexer_for_filename
from pygments.modeline import get_filetype_from_buffer
from pygments.util import ClassNotFound
log = logging.getLogger('WakaTime')
# force file name extensions to be recognized as a certain language
EXTENSIONS = {
'j2': 'HTML',
'markdown': 'Markdown',
'md': 'Markdown',
'mdown': 'Markdown',
'twig': 'Twig',
}
TRANSLATIONS = {
'CSS+Genshi Text': 'CSS',
'CSS+Lasso': 'CSS',
'HTML+Django/Jinja': 'HTML',
'HTML+Lasso': 'HTML',
'JavaScript+Genshi Text': 'JavaScript',
'JavaScript+Lasso': 'JavaScript',
'Perl6': 'Perl',
'RHTML': 'HTML',
}
def guess_language(file_name):
language, lexer = None, None
try:
with open(file_name, 'r', encoding='utf-8') as fh:
lexer = guess_lexer_for_filename(file_name, fh.read(512000))
except:
pass
if file_name:
language = guess_language_from_extension(file_name.rsplit('.', 1)[-1])
if lexer and language is None:
language = translate_language(u(lexer.name))
"""Guess lexer and language for a file.
Returns (language, lexer) tuple where language is a unicode string.
"""
language = get_language_from_extension(file_name)
lexer = smart_guess_lexer(file_name)
if language is None and lexer is not None:
language = u(lexer.name)
return language, lexer
def guess_language_from_extension(extension):
if extension:
if extension in EXTENSIONS:
return EXTENSIONS[extension]
if extension.lower() in EXTENSIONS:
return EXTENSIONS[extension.lower()]
def smart_guess_lexer(file_name):
"""Guess Pygments lexer for a file.
Looks for a vim modeline in file contents, then compares the accuracy
of that lexer with a second guess. The second guess looks up all lexers
matching the file name, then runs a text analysis for the best choice.
"""
lexer = None
text = get_file_contents(file_name)
lexer1, accuracy1 = guess_lexer_using_filename(file_name, text)
lexer2, accuracy2 = guess_lexer_using_modeline(text)
if lexer1:
lexer = lexer1
if (lexer2 and accuracy2 and
(not accuracy1 or accuracy2 > accuracy1)):
lexer = lexer2 # pragma: nocover
return lexer
def guess_lexer_using_filename(file_name, text):
"""Guess lexer for given text, limited to lexers for this file's extension.
Returns a tuple of (lexer, accuracy).
"""
lexer, accuracy = None, None
try:
lexer = guess_lexer_for_filename(file_name, text)
except: # pragma: nocover
pass
if lexer is not None:
try:
accuracy = lexer.analyse_text(text)
except: # pragma: nocover
pass
return lexer, accuracy
def guess_lexer_using_modeline(text):
"""Guess lexer for given text using Vim modeline.
Returns a tuple of (lexer, accuracy).
"""
lexer, accuracy = None, None
file_type = None
try:
file_type = get_filetype_from_buffer(text)
except: # pragma: nocover
pass
if file_type is not None:
try:
lexer = get_lexer_by_name(file_type)
except ClassNotFound: # pragma: nocover
pass
if lexer is not None:
try:
accuracy = lexer.analyse_text(text)
except: # pragma: nocover
pass
return lexer, accuracy
def get_language_from_extension(file_name):
"""Returns a matching language for the given file extension.
"""
filepart, extension = os.path.splitext(file_name)
if os.path.exists(u('{0}{1}').format(u(filepart), u('.c'))) or os.path.exists(u('{0}{1}').format(u(filepart), u('.C'))):
return 'C'
extension = extension.lower()
if extension == '.h':
directory = os.path.dirname(file_name)
available_files = os.listdir(directory)
available_extensions = list(zip(*map(os.path.splitext, available_files)))[1]
available_extensions = [ext.lower() for ext in available_extensions]
if '.cpp' in available_extensions:
return 'C++'
if '.c' in available_extensions:
return 'C'
return None
def translate_language(language):
if language in TRANSLATIONS:
language = TRANSLATIONS[language]
return language
def number_lines_in_file(file_name):
lines = 0
try:
with open(file_name, 'r', encoding='utf-8') as fh:
for line in fh:
lines += 1
except:
return None
except: # pragma: nocover
try:
with open(file_name, 'r', encoding=sys.getfilesystemencoding()) as fh:
for line in fh:
lines += 1
except:
return None
return lines
def get_file_stats(file_name, notfile=False):
if notfile:
def get_file_stats(file_name, entity_type='file', lineno=None, cursorpos=None):
if entity_type != 'file':
stats = {
'language': None,
'dependencies': [],
'lines': None,
'lineno': lineno,
'cursorpos': cursorpos,
}
else:
language, lexer = guess_language(file_name)
@ -101,5 +172,24 @@ def get_file_stats(file_name, notfile=False):
'language': language,
'dependencies': dependencies,
'lines': number_lines_in_file(file_name),
'lineno': lineno,
'cursorpos': cursorpos,
}
return stats
def get_file_contents(file_name):
"""Returns the first 512000 bytes of the file's contents.
"""
text = None
try:
with open(file_name, 'r', encoding='utf-8') as fh:
text = fh.read(512000)
except: # pragma: nocover
try:
with open(file_name, 'r', encoding=sys.getfilesystemencoding()) as fh:
text = fh.read(512000)
except:
log.traceback('debug')
return text