From c952fab44061d8b17c32b8b11065637984f21873 Mon Sep 17 00:00:00 2001 From: krateng Date: Sat, 26 Feb 2022 21:44:38 +0100 Subject: [PATCH] I have committed various war crimes and cannot enter heaven as a result --- maloja/database/sqldb.py | 333 +++++++++++++++++++++------------------ 1 file changed, 178 insertions(+), 155 deletions(-) diff --git a/maloja/database/sqldb.py b/maloja/database/sqldb.py index 2f9ecff..3cd77d8 100644 --- a/maloja/database/sqldb.py +++ b/maloja/database/sqldb.py @@ -58,6 +58,20 @@ DB['associated_artists'] = sql.Table( meta.create_all(engine) + +# decorator that passes either the provided dbconn, or creates a separate one +# just for this function call +def connection_provider(func): + + def wrapper(*args,**kwargs): + if kwargs.get("dbconn") is not None: + return func(*args,**kwargs) + else: + with engine.connect() as connection: + kwargs['dbconn'] = connection + return func(*args,**kwargs) + return wrapper + ##### DB <-> Dict translations ## ATTENTION ALL ADVENTURERS @@ -167,9 +181,11 @@ def artist_dict_to_db(info): ##### Actual Database interactions +@connection_provider def add_scrobble(scrobbledict,dbconn=None): - add_scrobbles([scrobbledict]) + add_scrobbles([scrobbledict],dbconn=dbconn) +@connection_provider def add_scrobbles(scrobbleslist,dbconn=None): ops = [ @@ -178,86 +194,86 @@ def add_scrobbles(scrobbleslist,dbconn=None): ) for s in scrobbleslist ] - with engine.begin() as conn: - for op in ops: - try: - conn.execute(op) - except sql.exc.IntegrityError: - pass + + for op in ops: + try: + dbconn.execute(op) + except sql.exc.IntegrityError: + pass ### these will 'get' the ID of an entity, creating it if necessary @cached_wrapper +@connection_provider def get_track_id(trackdict,dbconn=None): ntitle = normalize_name(trackdict['title']) artist_ids = [get_artist_id(a) for a in trackdict['artists']] - with engine.begin() as conn: - op = DB['tracks'].select( - DB['tracks'].c.id - ).where( - DB['tracks'].c.title_normalized==ntitle - ) - result = conn.execute(op).all() + + op = DB['tracks'].select( + DB['tracks'].c.id + ).where( + DB['tracks'].c.title_normalized==ntitle + ) + result = dbconn.execute(op).all() for row in result: # check if the artists are the same foundtrackartists = [] - with engine.begin() as conn: - op = DB['trackartists'].select( - DB['trackartists'].c.artist_id - ).where( - DB['trackartists'].c.track_id==row[0] - ) - result = conn.execute(op).all() + + op = DB['trackartists'].select( + DB['trackartists'].c.artist_id + ).where( + DB['trackartists'].c.track_id==row[0] + ) + result = dbconn.execute(op).all() match_artist_ids = [r.artist_id for r in result] #print("required artists",artist_ids,"this match",match_artist_ids) if set(artist_ids) == set(match_artist_ids): #print("ID for",trackdict['title'],"was",row[0]) return row.id - with engine.begin() as conn: - op = DB['tracks'].insert().values( - **track_dict_to_db(trackdict) + + op = DB['tracks'].insert().values( + **track_dict_to_db(trackdict) + ) + result = dbconn.execute(op) + track_id = result.inserted_primary_key[0] + + for artist_id in artist_ids: + op = DB['trackartists'].insert().values( + track_id=track_id, + artist_id=artist_id ) - result = conn.execute(op) - track_id = result.inserted_primary_key[0] - with engine.begin() as conn: - for artist_id in artist_ids: - op = DB['trackartists'].insert().values( - track_id=track_id, - artist_id=artist_id - ) - result = conn.execute(op) - #print("Created",trackdict['title'],track_id) - return track_id + result = dbconn.execute(op) + #print("Created",trackdict['title'],track_id) + return track_id @cached_wrapper +@connection_provider def get_artist_id(artistname,dbconn=None): nname = normalize_name(artistname) #print("looking for",nname) - with engine.begin() as conn: - op = DB['artists'].select( - DB['artists'].c.id - ).where( - DB['artists'].c.name_normalized==nname - ) - result = conn.execute(op).all() + op = DB['artists'].select( + DB['artists'].c.id + ).where( + DB['artists'].c.name_normalized==nname + ) + result = dbconn.execute(op).all() for row in result: #print("ID for",artistname,"was",row[0]) return row.id - with engine.begin() as conn: - op = DB['artists'].insert().values( - name=artistname, - name_normalized=nname - ) - result = conn.execute(op) - #print("Created",artistname,result.inserted_primary_key) - return result.inserted_primary_key[0] + op = DB['artists'].insert().values( + name=artistname, + name_normalized=nname + ) + result = dbconn.execute(op) + #print("Created",artistname,result.inserted_primary_key) + return result.inserted_primary_key[0] @@ -266,6 +282,7 @@ def get_artist_id(artistname,dbconn=None): ### Functions that get rows according to parameters @cached_wrapper +@connection_provider def get_scrobbles_of_artist(artist,since=None,to=None,dbconn=None): if since is None: since=0 @@ -274,19 +291,20 @@ def get_scrobbles_of_artist(artist,since=None,to=None,dbconn=None): artist_id = get_artist_id(artist) jointable = sql.join(DB['scrobbles'],DB['trackartists'],DB['scrobbles'].c.track_id == DB['trackartists'].c.track_id) - with engine.begin() as conn: - op = jointable.select().where( - DB['scrobbles'].c.timestamp<=to, - DB['scrobbles'].c.timestamp>=since, - DB['trackartists'].c.artist_id==artist_id - ).order_by(sql.asc('timestamp')) - result = conn.execute(op).all() + + op = jointable.select().where( + DB['scrobbles'].c.timestamp<=to, + DB['scrobbles'].c.timestamp>=since, + DB['trackartists'].c.artist_id==artist_id + ).order_by(sql.asc('timestamp')) + result = dbconn.execute(op).all() result = scrobbles_db_to_dict(result) #result = [scrobble_db_to_dict(row,resolve_references=resolve_references) for row in result] return result @cached_wrapper +@connection_provider def get_scrobbles_of_track(track,since=None,to=None,dbconn=None): if since is None: since=0 @@ -294,79 +312,82 @@ def get_scrobbles_of_track(track,since=None,to=None,dbconn=None): track_id = get_track_id(track) - with engine.begin() as conn: - op = DB['scrobbles'].select().where( - DB['scrobbles'].c.timestamp<=to, - DB['scrobbles'].c.timestamp>=since, - DB['scrobbles'].c.track_id==track_id - ).order_by(sql.asc('timestamp')) - result = conn.execute(op).all() + op = DB['scrobbles'].select().where( + DB['scrobbles'].c.timestamp<=to, + DB['scrobbles'].c.timestamp>=since, + DB['scrobbles'].c.track_id==track_id + ).order_by(sql.asc('timestamp')) + result = dbconn.execute(op).all() result = scrobbles_db_to_dict(result) #result = [scrobble_db_to_dict(row) for row in result] return result @cached_wrapper +@connection_provider def get_scrobbles(since=None,to=None,resolve_references=True,dbconn=None): if since is None: since=0 if to is None: to=now() - with engine.begin() as conn: - op = DB['scrobbles'].select().where( - DB['scrobbles'].c.timestamp<=to, - DB['scrobbles'].c.timestamp>=since, - ).order_by(sql.asc('timestamp')) - result = conn.execute(op).all() + op = DB['scrobbles'].select().where( + DB['scrobbles'].c.timestamp<=to, + DB['scrobbles'].c.timestamp>=since, + ).order_by(sql.asc('timestamp')) + result = dbconn.execute(op).all() result = scrobbles_db_to_dict(result) #result = [scrobble_db_to_dict(row,resolve_references=resolve_references) for i,row in enumerate(result) if i=since - ).group_by( - sql.func.coalesce(DB['associated_artists'].c.target_artist,DB['trackartists'].c.artist_id) - ).order_by(sql.desc('count')) - result = conn.execute(op).all() + op = sql.select( + sql.func.count(sql.func.distinct(DB['scrobbles'].c.timestamp)).label('count'), + # only count distinct scrobbles - because of artist replacement, we could end up + # with two artists of the same scrobble counting it twice for the same artist + # e.g. Irene and Seulgi adding two scrobbles to Red Velvet for one real scrobble + sql.func.coalesce(DB['associated_artists'].c.target_artist,DB['trackartists'].c.artist_id).label('artist_id') + # use the replaced artist as artist to count if it exists, otherwise original one + ).select_from(jointable2).where( + DB['scrobbles'].c.timestamp<=to, + DB['scrobbles'].c.timestamp>=since + ).group_by( + sql.func.coalesce(DB['associated_artists'].c.target_artist,DB['trackartists'].c.artist_id) + ).order_by(sql.desc('count')) + result = dbconn.execute(op).all() counts = [row.count for row in result] @@ -404,17 +424,18 @@ def count_scrobbles_by_artist(since,to,dbconn=None): return result @cached_wrapper +@connection_provider def count_scrobbles_by_track(since,to,dbconn=None): - with engine.begin() as conn: - op = sql.select( - sql.func.count(sql.func.distinct(DB['scrobbles'].c.timestamp)).label('count'), - DB['scrobbles'].c.track_id - ).select_from(DB['scrobbles']).where( - DB['scrobbles'].c.timestamp<=to, - DB['scrobbles'].c.timestamp>=since - ).group_by(DB['scrobbles'].c.track_id).order_by(sql.desc('count')) - result = conn.execute(op).all() + + op = sql.select( + sql.func.count(sql.func.distinct(DB['scrobbles'].c.timestamp)).label('count'), + DB['scrobbles'].c.track_id + ).select_from(DB['scrobbles']).where( + DB['scrobbles'].c.timestamp<=to, + DB['scrobbles'].c.timestamp>=since + ).group_by(DB['scrobbles'].c.track_id).order_by(sql.desc('count')) + result = dbconn.execute(op).all() counts = [row.count for row in result] @@ -424,6 +445,7 @@ def count_scrobbles_by_track(since,to,dbconn=None): return result @cached_wrapper +@connection_provider def count_scrobbles_by_track_of_artist(since,to,artist,dbconn=None): artist_id = get_artist_id(artist) @@ -434,16 +456,15 @@ def count_scrobbles_by_track_of_artist(since,to,artist,dbconn=None): DB['scrobbles'].c.track_id == DB['trackartists'].c.track_id ) - with engine.begin() as conn: - op = sql.select( - sql.func.count(sql.func.distinct(DB['scrobbles'].c.timestamp)).label('count'), - DB['scrobbles'].c.track_id - ).select_from(jointable).filter( - DB['scrobbles'].c.timestamp<=to, - DB['scrobbles'].c.timestamp>=since, - DB['trackartists'].c.artist_id==artist_id - ).group_by(DB['scrobbles'].c.track_id).order_by(sql.desc('count')) - result = conn.execute(op).all() + op = sql.select( + sql.func.count(sql.func.distinct(DB['scrobbles'].c.timestamp)).label('count'), + DB['scrobbles'].c.track_id + ).select_from(jointable).filter( + DB['scrobbles'].c.timestamp<=to, + DB['scrobbles'].c.timestamp>=since, + DB['trackartists'].c.artist_id==artist_id + ).group_by(DB['scrobbles'].c.track_id).order_by(sql.desc('count')) + result = dbconn.execute(op).all() counts = [row.count for row in result] @@ -458,12 +479,12 @@ def count_scrobbles_by_track_of_artist(since,to,artist,dbconn=None): ### functions that get mappings for several entities -> rows @cached_wrapper +@connection_provider def get_artists_of_tracks(track_ids,dbconn=None): - with engine.begin() as conn: - op = sql.join(DB['trackartists'],DB['artists']).select().where( - DB['trackartists'].c.track_id.in_(track_ids) - ) - result = conn.execute(op).all() + op = sql.join(DB['trackartists'],DB['artists']).select().where( + DB['trackartists'].c.track_id.in_(track_ids) + ) + result = dbconn.execute(op).all() artists = {} for row in result: @@ -472,12 +493,12 @@ def get_artists_of_tracks(track_ids,dbconn=None): @cached_wrapper +@connection_provider def get_tracks_map(track_ids,dbconn=None): - with engine.begin() as conn: - op = DB['tracks'].select().where( - DB['tracks'].c.id.in_(track_ids) - ) - result = conn.execute(op).all() + op = DB['tracks'].select().where( + DB['tracks'].c.id.in_(track_ids) + ) + result = dbconn.execute(op).all() tracks = {} trackids = [row.id for row in result] @@ -487,12 +508,13 @@ def get_tracks_map(track_ids,dbconn=None): return tracks @cached_wrapper +@connection_provider def get_artists_map(artist_ids,dbconn=None): - with engine.begin() as conn: - op = DB['artists'].select().where( - DB['artists'].c.id.in_(artist_ids) - ) - result = conn.execute(op).all() + + op = DB['artists'].select().where( + DB['artists'].c.id.in_(artist_ids) + ) + result = dbconn.execute(op).all() artists = {} artistids = [row.id for row in result] @@ -505,6 +527,7 @@ def get_artists_map(artist_ids,dbconn=None): ### associations @cached_wrapper +@connection_provider def get_associated_artists(dbconn=None,*artists): artist_ids = [get_artist_id(a) for a in artists] @@ -514,16 +537,16 @@ def get_associated_artists(dbconn=None,*artists): DB['associated_artists'].c.source_artist == DB['artists'].c.id ) - with engine.begin() as conn: - op = jointable.select().where( - DB['associated_artists'].c.target_artist.in_(artist_ids) - ) - result = conn.execute(op).all() + op = jointable.select().where( + DB['associated_artists'].c.target_artist.in_(artist_ids) + ) + result = dbconn.execute(op).all() artists = artists_db_to_dict(result) return artists @cached_wrapper +@connection_provider def get_credited_artists(dbconn=None,*artists): artist_ids = [get_artist_id(a) for a in artists] @@ -533,11 +556,11 @@ def get_credited_artists(dbconn=None,*artists): DB['associated_artists'].c.target_artist == DB['artists'].c.id ) - with engine.begin() as conn: - op = jointable.select().where( - DB['associated_artists'].c.source_artist.in_(artist_ids) - ) - result = conn.execute(op).all() + + op = jointable.select().where( + DB['associated_artists'].c.source_artist.in_(artist_ids) + ) + result = dbconn.execute(op).all() artists = artists_db_to_dict(result) return artists @@ -546,23 +569,23 @@ def get_credited_artists(dbconn=None,*artists): ### get a specific entity by id @cached_wrapper +@connection_provider def get_track(id,dbconn=None): - with engine.begin() as conn: - op = DB['tracks'].select().where( - DB['tracks'].c.id==id - ) - result = conn.execute(op).all() + op = DB['tracks'].select().where( + DB['tracks'].c.id==id + ) + result = dbconn.execute(op).all() trackinfo = result[0] return track_db_to_dict(trackinfo) @cached_wrapper +@connection_provider def get_artist(id,dbconn=None): - with engine.begin() as conn: - op = DB['artists'].select().where( - DB['artists'].c.id==id - ) - result = conn.execute(op).all() + op = DB['artists'].select().where( + DB['artists'].c.id==id + ) + result = dbconn.execute(op).all() artistinfo = result[0] return artist_db_to_dict(artistinfo)