1
0
mirror of https://git.ikl.sh/132ikl/liteshort.git synced 2023-08-10 21:13:04 +03:00

Run Black formatting, isort, autoflake

This commit is contained in:
132ikl 2020-01-16 14:35:43 -05:00
parent aad08830f0
commit 1fe86f4e60
2 changed files with 191 additions and 118 deletions

View File

@ -3,67 +3,109 @@
# This software is license under the MIT license. It should be included in your copy of this software. # This software is license under the MIT license. It should be included in your copy of this software.
# A copy of the MIT license can be obtained at https://mit-license.org/ # A copy of the MIT license can be obtained at https://mit-license.org/
from flask import Flask, current_app, flash, g, jsonify, make_response, redirect, render_template, request, send_from_directory, url_for
import bcrypt
import os import os
import random import random
import sqlite3 import sqlite3
import time import time
import urllib import urllib
import bcrypt
import yaml import yaml
from flask import (Flask, current_app, flash, g, jsonify, make_response,
redirect, render_template, request, send_from_directory,
url_for)
app = Flask(__name__) app = Flask(__name__)
def load_config(): def load_config():
new_config = yaml.load(open('config.yml')) new_config = yaml.load(open("config.yml"))
new_config = {k.lower(): v for k, v in new_config.items()} # Make config keys case insensitive new_config = {
k.lower(): v for k, v in new_config.items()
} # Make config keys case insensitive
req_options = {'admin_username': 'admin', 'database_name': "urls", 'random_length': 4, req_options = {
'allowed_chars': 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_', "admin_username": "admin",
'random_gen_timeout': 5, 'site_name': 'liteshort', 'site_domain': None, 'show_github_link': True, "database_name": "urls",
'secret_key': None, 'disable_api': False, 'subdomain': '', 'latest': 'l', 'selflinks': False "random_length": 4,
} "allowed_chars": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
"random_gen_timeout": 5,
"site_name": "liteshort",
"site_domain": None,
"show_github_link": True,
"secret_key": None,
"disable_api": False,
"subdomain": "",
"latest": "l",
"selflinks": False,
}
config_types = {'admin_username': str, 'database_name': str, 'random_length': int, config_types = {
'allowed_chars': str, 'random_gen_timeout': int, 'site_name': str, "admin_username": str,
'site_domain': (str, type(None)), 'show_github_link': bool, 'secret_key': str, "database_name": str,
'disable_api': bool, 'subdomain': (str, type(None)), 'latest': (str, type(None)), 'selflinks': bool "random_length": int,
} "allowed_chars": str,
"random_gen_timeout": int,
"site_name": str,
"site_domain": (str, type(None)),
"show_github_link": bool,
"secret_key": str,
"disable_api": bool,
"subdomain": (str, type(None)),
"latest": (str, type(None)),
"selflinks": bool,
}
for option in req_options.keys(): for option in req_options.keys():
if option not in new_config.keys(): # Make sure everything in req_options is set in config if (
option not in new_config.keys()
): # Make sure everything in req_options is set in config
new_config[option] = req_options[option] new_config[option] = req_options[option]
for option in new_config.keys(): for option in new_config.keys():
if option in config_types: if option in config_types:
matches = False matches = False
if type(config_types[option]) is not tuple: if type(config_types[option]) is not tuple:
config_types[option] = (config_types[option],) # Automatically creates tuple for non-tuple types config_types[option] = (
for req_type in config_types[option]: # Iterates through tuple to allow multiple types for config options config_types[option],
) # Automatically creates tuple for non-tuple types
for req_type in config_types[
option
]: # Iterates through tuple to allow multiple types for config options
if type(new_config[option]) is req_type: if type(new_config[option]) is req_type:
matches = True matches = True
if not matches: if not matches:
raise TypeError(option + " is incorrect type") raise TypeError(option + " is incorrect type")
if not new_config['disable_api']: if not new_config["disable_api"]:
if 'admin_hashed_password' in new_config.keys() and new_config['admin_hashed_password']: if (
new_config['password_hashed'] = True "admin_hashed_password" in new_config.keys()
elif 'admin_password' in new_config.keys() and new_config['admin_password']: and new_config["admin_hashed_password"]
new_config['password_hashed'] = False ):
new_config["password_hashed"] = True
elif "admin_password" in new_config.keys() and new_config["admin_password"]:
new_config["password_hashed"] = False
else: else:
raise TypeError('admin_password or admin_hashed_password must be set in config.yml') raise TypeError(
"admin_password or admin_hashed_password must be set in config.yml"
)
return new_config return new_config
def authenticate(username, password): def authenticate(username, password):
return username == current_app.config['admin_username'] and check_password(password, current_app.config) return username == current_app.config["admin_username"] and check_password(
password, current_app.config
)
def check_long_exist(long): def check_long_exist(long):
query = query_db('SELECT short FROM urls WHERE long = ?', (long,)) query = query_db("SELECT short FROM urls WHERE long = ?", (long,))
for i in query: for i in query:
if i and (len(i['short']) <= current_app.config["random_length"]) and i['short'] != current_app.config['latest']: # Checks if query if pre-existing URL is same as random length URL if (
return i['short'] i
and (len(i["short"]) <= current_app.config["random_length"])
and i["short"] != current_app.config["latest"]
): # Checks if query if pre-existing URL is same as random length URL
return i["short"]
return False return False
@ -72,23 +114,30 @@ def check_short_exist(short): # Allow to also check against a long link
return True return True
return False return False
def check_self_link(long): def check_self_link(long):
if get_baseUrl().rstrip('/') in long: if get_baseUrl().rstrip("/") in long:
return True return True
return False return False
def check_password(password, pass_config): def check_password(password, pass_config):
if pass_config['password_hashed']: if pass_config["password_hashed"]:
return bcrypt.checkpw(password.encode('utf-8'), pass_config['admin_hashed_password'].encode('utf-8')) return bcrypt.checkpw(
elif not pass_config['password_hashed']: password.encode("utf-8"),
return password == pass_config['admin_password'] pass_config["admin_hashed_password"].encode("utf-8"),
)
elif not pass_config["password_hashed"]:
return password == pass_config["admin_password"]
else: else:
raise RuntimeError('This should never occur! Bailing...') raise RuntimeError("This should never occur! Bailing...")
def delete_url(deletion): def delete_url(deletion):
result = query_db('SELECT * FROM urls WHERE short = ?', (deletion,), False, None) # Return as tuple instead of row result = query_db(
get_db().cursor().execute('DELETE FROM urls WHERE short = ?', (deletion,)) "SELECT * FROM urls WHERE short = ?", (deletion,), False, None
) # Return as tuple instead of row
get_db().cursor().execute("DELETE FROM urls WHERE short = ?", (deletion,))
get_db().commit() get_db().commit()
return len(result) return len(result)
@ -101,33 +150,35 @@ def dict_factory(cursor, row):
def generate_short(rq): def generate_short(rq):
timeout = time.time() + current_app.config['random_gen_timeout'] timeout = time.time() + current_app.config["random_gen_timeout"]
while True: while True:
if time.time() >= timeout: if time.time() >= timeout:
return response(rq, None, 'Timeout while generating random short URL') return response(rq, None, "Timeout while generating random short URL")
short = ''.join(random.choice(current_app.config['allowed_chars']) short = "".join(
for i in range(current_app.config['random_length'])) random.choice(current_app.config["allowed_chars"])
if not check_short_exist(short) and short != app.config['latest']: for i in range(current_app.config["random_length"])
)
if not check_short_exist(short) and short != app.config["latest"]:
return short return short
def get_long(short): def get_long(short):
row = query_db('SELECT long FROM urls WHERE short = ?', (short,), True) row = query_db("SELECT long FROM urls WHERE short = ?", (short,), True)
if row and row['long']: if row and row["long"]:
return row['long'] return row["long"]
return None return None
def get_baseUrl(): def get_baseUrl():
if current_app.config['site_domain']: if current_app.config["site_domain"]:
# TODO: un-hack-ify adding the protocol here # TODO: un-hack-ify adding the protocol here
return 'https://' + current_app.config['site_domain'] + '/' return "https://" + current_app.config["site_domain"] + "/"
else: else:
return request.base_url return request.base_url
def list_shortlinks(): def list_shortlinks():
result = query_db('SELECT * FROM urls', (), False, None) result = query_db("SELECT * FROM urls", (), False, None)
result = nested_list_to_dict(result) result = nested_list_to_dict(result)
return result return result
@ -140,9 +191,9 @@ def nested_list_to_dict(l):
def response(rq, result, error_msg="Error: Unknown error"): def response(rq, result, error_msg="Error: Unknown error"):
if rq.form.get('api') and not rq.form.get('format') == 'json': if rq.form.get("api") and not rq.form.get("format") == "json":
return "Format type HTML (default) not support for API" # Future-proof for non-json return types return "Format type HTML (default) not support for API" # Future-proof for non-json return types
if rq.form.get('format') == 'json': if rq.form.get("format") == "json":
# If not result provided OR result doesn't exist, send error # If not result provided OR result doesn't exist, send error
# Allows for setting an error message with explicitly checking in regular code # Allows for setting an error message with explicitly checking in regular code
if result: if result:
@ -154,30 +205,40 @@ def response(rq, result, error_msg="Error: Unknown error"):
return jsonify(success=False, error=error_msg) return jsonify(success=False, error=error_msg)
else: else:
if result and result is not True: if result and result is not True:
flash(result, 'success') flash(result, "success")
elif not result: elif not result:
flash(error_msg, 'error') flash(error_msg, "error")
return render_template("main.html") return render_template("main.html")
def set_latest(long): def set_latest(long):
if app.config['latest']: if app.config["latest"]:
if query_db('SELECT short FROM urls WHERE short = ?', (current_app.config['latest'],)): if query_db(
get_db().cursor().execute("UPDATE urls SET long = ? WHERE short = ?", "SELECT short FROM urls WHERE short = ?", (current_app.config["latest"],)
(long, current_app.config['latest'])) ):
get_db().cursor().execute(
"UPDATE urls SET long = ? WHERE short = ?",
(long, current_app.config["latest"]),
)
else: else:
get_db().cursor().execute("INSERT INTO urls (long,short) VALUES (?, ?)", get_db().cursor().execute(
(long, current_app.config['latest'])) "INSERT INTO urls (long,short) VALUES (?, ?)",
(long, current_app.config["latest"]),
)
def validate_short(short): def validate_short(short):
if short == app.config['latest']: if short == app.config["latest"]:
return response(request, None, return response(
'Short URL cannot be the same as a special URL ({})'.format(short)) request,
None,
"Short URL cannot be the same as a special URL ({})".format(short),
)
for char in short: for char in short:
if char not in current_app.config['allowed_chars']: if char not in current_app.config["allowed_chars"]:
return response(request, None, return response(
'Character ' + char + ' not allowed in short URL') request, None, "Character " + char + " not allowed in short URL"
)
return True return True
@ -185,16 +246,17 @@ def validate_long(long): # https://stackoverflow.com/a/36283503
token = urllib.parse.urlparse(long) token = urllib.parse.urlparse(long)
return all([token.scheme, token.netloc]) return all([token.scheme, token.netloc])
# Database connection functions # Database connection functions
def get_db(): def get_db():
if 'db' not in g: if "db" not in g:
g.db = sqlite3.connect( g.db = sqlite3.connect(
''.join((current_app.config['database_name'], '.db')), "".join((current_app.config["database_name"], ".db")),
detect_types=sqlite3.PARSE_DECLTYPES detect_types=sqlite3.PARSE_DECLTYPES,
) )
g.db.cursor().execute('CREATE TABLE IF NOT EXISTS urls (long,short)') g.db.cursor().execute("CREATE TABLE IF NOT EXISTS urls (long,short)")
return g.db return g.db
@ -208,104 +270,116 @@ def query_db(query, args=(), one=False, row_factory=sqlite3.Row):
@app.teardown_appcontext @app.teardown_appcontext
def close_db(error): def close_db(error):
if hasattr(g, 'sqlite_db'): if hasattr(g, "sqlite_db"):
g.sqlite_db.close() g.sqlite_db.close()
app.config.update(load_config()) # Add YAML config to Flask config app.config.update(load_config()) # Add YAML config to Flask config
app.secret_key = app.config['secret_key'] app.secret_key = app.config["secret_key"]
app.config['SERVER_NAME'] = app.config['site_domain'] app.config["SERVER_NAME"] = app.config["site_domain"]
@app.route('/favicon.ico', subdomain=app.config['subdomain']) @app.route("/favicon.ico", subdomain=app.config["subdomain"])
def favicon(): def favicon():
return send_from_directory(os.path.join(app.root_path, 'static'), return send_from_directory(
'favicon.ico', mimetype='image/vnd.microsoft.icon') os.path.join(app.root_path, "static"),
"favicon.ico",
mimetype="image/vnd.microsoft.icon",
)
@app.route('/', subdomain=app.config['subdomain']) @app.route("/", subdomain=app.config["subdomain"])
def main(): def main():
return response(request, True) return response(request, True)
@app.route('/<url>') @app.route("/<url>")
def main_redir(url): def main_redir(url):
long = get_long(url) long = get_long(url)
if long: if long:
resp = make_response(redirect(long, 301)) resp = make_response(redirect(long, 301))
else: else:
flash('Short URL "' + url + '" doesn\'t exist', 'error') flash('Short URL "' + url + "\" doesn't exist", "error")
resp = make_response(redirect(url_for('main'))) resp = make_response(redirect(url_for("main")))
resp.headers.set('Cache-Control', 'no-store, must-revalidate') resp.headers.set("Cache-Control", "no-store, must-revalidate")
return resp return resp
@app.route('/', methods=['POST'], subdomain=app.config['subdomain']) @app.route("/", methods=["POST"], subdomain=app.config["subdomain"])
def main_post(): def main_post():
if request.form.get('long'): if request.form.get("long"):
if not validate_long(request.form['long']): if not validate_long(request.form["long"]):
return response(request, None, "Long URL is not valid") return response(request, None, "Long URL is not valid")
if request.form.get('short'): if request.form.get("short"):
# Validate long as URL and short custom text against allowed characters # Validate long as URL and short custom text against allowed characters
result = validate_short(request.form['short']) result = validate_short(request.form["short"])
if validate_short(request.form['short']) is True: if validate_short(request.form["short"]) is True:
short = request.form['short'] short = request.form["short"]
else: else:
return result return result
if get_long(short) == request.form['long']: if get_long(short) == request.form["long"]:
return response(request, get_baseUrl() + short, return response(
'Error: Failed to return pre-existing non-random shortlink') request,
get_baseUrl() + short,
"Error: Failed to return pre-existing non-random shortlink",
)
else: else:
short = generate_short(request) short = generate_short(request)
if check_short_exist(short): if check_short_exist(short):
return response(request, None, return response(request, None, "Short URL already taken")
'Short URL already taken') long_exists = check_long_exist(request.form["long"])
long_exists = check_long_exist(request.form['long']) if (
if check_self_link(request.form['long']) and not current_app.config['selflinks']: check_self_link(request.form["long"])
return response(request, None, and not current_app.config["selflinks"]
'You cannot link to this site') ):
if long_exists and not request.form.get('short'): return response(request, None, "You cannot link to this site")
set_latest(request.form['long']) if long_exists and not request.form.get("short"):
set_latest(request.form["long"])
get_db().commit() get_db().commit()
return response(request, get_baseUrl() + long_exists, return response(
'Error: Failed to return pre-existing random shortlink') request,
get_db().cursor().execute('INSERT INTO urls (long,short) VALUES (?,?)', (request.form['long'], short)) get_baseUrl() + long_exists,
set_latest(request.form['long']) "Error: Failed to return pre-existing random shortlink",
)
get_db().cursor().execute(
"INSERT INTO urls (long,short) VALUES (?,?)", (request.form["long"], short)
)
set_latest(request.form["long"])
get_db().commit() get_db().commit()
return response(request, get_baseUrl() + short, return response(request, get_baseUrl() + short, "Error: Failed to generate")
'Error: Failed to generate') elif request.form.get("api"):
elif request.form.get('api'): if current_app.config["disable_api"]:
if current_app.config['disable_api']:
return response(request, None, "API is disabled.") return response(request, None, "API is disabled.")
# All API calls require authentication # All API calls require authentication
if not request.authorization \ if not request.authorization or not authenticate(
or not authenticate(request.authorization['username'], request.authorization['password']): request.authorization["username"], request.authorization["password"]
):
return response(request, None, "BaiscAuth failed") return response(request, None, "BaiscAuth failed")
command = request.form['api'] command = request.form["api"]
if command == 'list' or command == 'listshort': if command == "list" or command == "listshort":
return response(request, list_shortlinks(), "Failed to list items") return response(request, list_shortlinks(), "Failed to list items")
elif command == 'listlong': elif command == "listlong":
shortlinks = list_shortlinks() shortlinks = list_shortlinks()
shortlinks = {v: k for k, v in shortlinks.items()} shortlinks = {v: k for k, v in shortlinks.items()}
return response(request, shortlinks, "Failed to list items") return response(request, shortlinks, "Failed to list items")
elif command == 'delete': elif command == "delete":
deleted = 0 deleted = 0
if 'long' not in request.form and 'short' not in request.form: if "long" not in request.form and "short" not in request.form:
return response(request, None, "Provide short or long in POST data") return response(request, None, "Provide short or long in POST data")
if 'short' in request.form: if "short" in request.form:
deleted = delete_url(request.form['short']) + deleted deleted = delete_url(request.form["short"]) + deleted
if 'long' in request.form: if "long" in request.form:
deleted = delete_url(request.form['long']) + deleted deleted = delete_url(request.form["long"]) + deleted
if deleted > 0: if deleted > 0:
return response(request, "Deleted " + str(deleted) + " URLs") return response(request, "Deleted " + str(deleted) + " URLs")
else: else:
return response(request, None, "Failed to delete URL") return response(request, None, "Failed to delete URL")
else: else:
return response(request, None, 'Command ' + command + ' not found') return response(request, None, "Command " + command + " not found")
else: else:
return response(request, None, 'Long URL required') return response(request, None, "Long URL required")
if __name__ == '__main__': if __name__ == "__main__":
app.run() app.run()

View File

@ -2,4 +2,3 @@ from liteshort import app
if __name__ == "__main__": if __name__ == "__main__":
app.run() app.run()