diff --git a/maloja/database.py b/maloja/database.py index b7a07e7..c0a1298 100644 --- a/maloja/database.py +++ b/maloja/database.py @@ -703,7 +703,7 @@ DB['scrobbles'] = sql.Table( sql.Column('rawscrobble',sql.String), sql.Column('origin',sql.String), sql.Column('duration',sql.Integer), - sql.Column('track_id',sql.Integer) + sql.Column('track_id',sql.Integer,sql.ForeignKey('tracks.id')) ) DB['tracks'] = sql.Table( 'tracks', meta, @@ -720,8 +720,8 @@ DB['artists'] = sql.Table( DB['trackartists'] = sql.Table( 'trackartists', meta, sql.Column('id',sql.Integer,primary_key=True), - sql.Column('artist_id',sql.Integer), - sql.Column('track_id',sql.Integer) + sql.Column('artist_id',sql.Integer,sql.ForeignKey('artists.id')), + sql.Column('track_id',sql.Integer,sql.ForeignKey('tracks.id')) ) meta.create_all(engine) @@ -792,10 +792,22 @@ def get_track_id(trackdict): ).where( DB['tracks'].c.title_normalized==ntitle ) - result = conn.execute(op) - for row in result: + result = conn.execute(op).all() + for row in result: + # check if the artists are the same + foundtrackartists = [] + with engine.begin() as conn: + op = DB['trackartists'].select( + DB['trackartists'].c.artist_id + ).where( + DB['trackartists'].c.track_id==row[0] + ) + result = conn.execute(op).all() + match_artist_ids = [r.artist_id for r in result] + print("required artists",artist_ids,"this match",match_artist_ids) + if set(artist_ids) == set(match_artist_ids): print("ID for",trackdict['title'],"was",row[0]) - return row[0] + return row.id with engine.begin() as conn: op = DB['tracks'].insert().values( @@ -803,8 +815,16 @@ def get_track_id(trackdict): title_normalized=ntitle ) result = conn.execute(op) - print("Created",trackdict['title'],result.inserted_primary_key) - return result.inserted_primary_key[0] + track_id = result.inserted_primary_key[0] + with engine.begin() as conn: + for artist_id in artist_ids: + op = DB['trackartists'].insert().values( + track_id=track_id, + artist_id=artist_id + ) + result = conn.execute(op) + print("Created",trackdict['title'],track_id) + return track_id def get_artist_id(artistname): nname = normalize_name(artistname) @@ -816,10 +836,10 @@ def get_artist_id(artistname): ).where( DB['artists'].c.name_normalized==nname ) - result = conn.execute(op) - for row in result: - print("ID for",artistname,"was",row[0]) - return row[0] + result = conn.execute(op).all() + for row in result: + print("ID for",artistname,"was",row[0]) + return row.id with engine.begin() as conn: op = DB['artists'].insert().values(